diff options
author | Nao Pross <np@0hm.ch> | 2024-05-27 15:00:03 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-27 15:00:03 +0200 |
commit | 62818491b9e3367a4a59556d0b1349c1d3e61ef6 (patch) | |
tree | 695f0f3ee58ed446d1bcfd90288e7644604b4527 | |
parent | Improve polymatrix.ones and zeros (diff) | |
download | polymatrix-62818491b9e3367a4a59556d0b1349c1d3e61ef6.tar.gz polymatrix-62818491b9e3367a4a59556d0b1349c1d3e61ef6.zip |
Fix bug with shapes
-rw-r--r-- | polymatrix/expression/mixins/elemmultexprmixin.py | 14 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 2 |
2 files changed, 10 insertions, 6 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index c7fbb7f..e83bf3a 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -27,12 +27,16 @@ class ElemMultExprMixin(ExpressionBaseMixin): left: PolyMatrix, right: PolyMatrix, ): - if left.shape != right.shape and left.shape == (1, 1): - left, right = right, left + if left.shape != right.shape: + if left.shape == (1,1): + left = init_broadcast_poly_matrix(left.scalar(), shape=right.shape) - if right.shape == (1, 1): - right_poly = right.get_poly(0, 0) - right = init_broadcast_poly_matrix(data=right_poly, shape=right.shape) + elif right.shape == (1,1): + right = init_broadcast_poly_matrix(right.scalar(), shape=left.shape) + + else: + raise NotImplementedError("Cannot do element-wise multiplication of matrices " + f"with shapes {left.shape} and {right.shape}.") poly_matrix_data = {} diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index ac6d199..a13017b 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -130,7 +130,7 @@ def init_slice_poly_matrix( raise TypeError("{what} {el} of type {type(el)} is not a valid slice type.") - return SlicePolyMatrixImpl(reference=reference, shape=shape, slice=tuple(formatted_slices)) + return SlicePolyMatrixImpl(reference=reference, shape=tuple(shape), slice=tuple(formatted_slices)) # FIXME: rename to init_affine_expression |