From 62818491b9e3367a4a59556d0b1349c1d3e61ef6 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 27 May 2024 15:00:03 +0200 Subject: Fix bug with shapes --- polymatrix/expression/mixins/elemmultexprmixin.py | 14 +++++++++----- polymatrix/polymatrix/init.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index c7fbb7f..e83bf3a 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -27,12 +27,16 @@ class ElemMultExprMixin(ExpressionBaseMixin): left: PolyMatrix, right: PolyMatrix, ): - if left.shape != right.shape and left.shape == (1, 1): - left, right = right, left + if left.shape != right.shape: + if left.shape == (1,1): + left = init_broadcast_poly_matrix(left.scalar(), shape=right.shape) - if right.shape == (1, 1): - right_poly = right.get_poly(0, 0) - right = init_broadcast_poly_matrix(data=right_poly, shape=right.shape) + elif right.shape == (1,1): + right = init_broadcast_poly_matrix(right.scalar(), shape=left.shape) + + else: + raise NotImplementedError("Cannot do element-wise multiplication of matrices " + f"with shapes {left.shape} and {right.shape}.") poly_matrix_data = {} diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index ac6d199..a13017b 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -130,7 +130,7 @@ def init_slice_poly_matrix( raise TypeError("{what} {el} of type {type(el)} is not a valid slice type.") - return SlicePolyMatrixImpl(reference=reference, shape=shape, slice=tuple(formatted_slices)) + return SlicePolyMatrixImpl(reference=reference, shape=tuple(shape), slice=tuple(formatted_slices)) # FIXME: rename to init_affine_expression -- cgit v1.2.1