summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-27 15:00:03 +0200
committerNao Pross <np@0hm.ch>2024-05-27 15:00:03 +0200
commit62818491b9e3367a4a59556d0b1349c1d3e61ef6 (patch)
tree695f0f3ee58ed446d1bcfd90288e7644604b4527
parentImprove polymatrix.ones and zeros (diff)
downloadpolymatrix-62818491b9e3367a4a59556d0b1349c1d3e61ef6.tar.gz
polymatrix-62818491b9e3367a4a59556d0b1349c1d3e61ef6.zip
Fix bug with shapes
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py14
-rw-r--r--polymatrix/polymatrix/init.py2
2 files changed, 10 insertions, 6 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index c7fbb7f..e83bf3a 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -27,12 +27,16 @@ class ElemMultExprMixin(ExpressionBaseMixin):
left: PolyMatrix,
right: PolyMatrix,
):
- if left.shape != right.shape and left.shape == (1, 1):
- left, right = right, left
+ if left.shape != right.shape:
+ if left.shape == (1,1):
+ left = init_broadcast_poly_matrix(left.scalar(), shape=right.shape)
- if right.shape == (1, 1):
- right_poly = right.get_poly(0, 0)
- right = init_broadcast_poly_matrix(data=right_poly, shape=right.shape)
+ elif right.shape == (1,1):
+ right = init_broadcast_poly_matrix(right.scalar(), shape=left.shape)
+
+ else:
+ raise NotImplementedError("Cannot do element-wise multiplication of matrices "
+ f"with shapes {left.shape} and {right.shape}.")
poly_matrix_data = {}
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index ac6d199..a13017b 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -130,7 +130,7 @@ def init_slice_poly_matrix(
raise TypeError("{what} {el} of type {type(el)} is not a valid slice type.")
- return SlicePolyMatrixImpl(reference=reference, shape=shape, slice=tuple(formatted_slices))
+ return SlicePolyMatrixImpl(reference=reference, shape=tuple(shape), slice=tuple(formatted_slices))
# FIXME: rename to init_affine_expression