diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/mixins/elemmultexprmixin.py | 14 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 2 |
2 files changed, 10 insertions, 6 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index c7fbb7f..e83bf3a 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -27,12 +27,16 @@ class ElemMultExprMixin(ExpressionBaseMixin): left: PolyMatrix, right: PolyMatrix, ): - if left.shape != right.shape and left.shape == (1, 1): - left, right = right, left + if left.shape != right.shape: + if left.shape == (1,1): + left = init_broadcast_poly_matrix(left.scalar(), shape=right.shape) - if right.shape == (1, 1): - right_poly = right.get_poly(0, 0) - right = init_broadcast_poly_matrix(data=right_poly, shape=right.shape) + elif right.shape == (1,1): + right = init_broadcast_poly_matrix(right.scalar(), shape=left.shape) + + else: + raise NotImplementedError("Cannot do element-wise multiplication of matrices " + f"with shapes {left.shape} and {right.shape}.") poly_matrix_data = {} diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index ac6d199..a13017b 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -130,7 +130,7 @@ def init_slice_poly_matrix( raise TypeError("{what} {el} of type {type(el)} is not a valid slice type.") - return SlicePolyMatrixImpl(reference=reference, shape=shape, slice=tuple(formatted_slices)) + return SlicePolyMatrixImpl(reference=reference, shape=tuple(shape), slice=tuple(formatted_slices)) # FIXME: rename to init_affine_expression |