From b4557d119212abcb4e0df452743a0fe91c97c751 Mon Sep 17 00:00:00 2001 From: Michael Schneeberger Date: Tue, 6 Dec 2022 13:13:32 +0100 Subject: bugfix: elementwise multiplication works with a scalar value on both sides --- polymatrix/expression/mixins/elemmultexprmixin.py | 41 ++--------------------- 1 file changed, 3 insertions(+), 38 deletions(-) diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index 408b8c8..ab1f265 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -36,6 +36,9 @@ class ElemMultExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) + if left.shape != right.shape and left.shape == (1, 1): + left, right = right, left + if right.shape == (1, 1): right_poly = right.get_poly(0, 0) @@ -52,44 +55,6 @@ class ElemMultExprMixin(ExpressionBaseMixin): shape=left.shape, ) - # if right.shape == (1, 1): - - # right_terms = right.get_poly(0, 0) - - # terms = {} - - # for poly_row in range(left.shape[0]): - # for poly_col in range(left.shape[1]): - - # terms_row_col = {} - - # try: - # left_terms = left.get_poly(poly_row, poly_col) - # except KeyError: - # continue - - # for (left_monomial, left_value), (right_monomial, right_value) \ - # in itertools.product(left_terms.items(), right_terms.items()): - - # value = left_value * right_value - - # # if value == 0: - # # continue - - # # monomial = tuple(sorted(left_monomial + right_monomial)) - - # new_monomial = merge_monomial_indices((left_monomial, right_monomial)) - - # if new_monomial not in terms_row_col: - # terms_row_col[new_monomial] = 0 - - # terms_row_col[new_monomial] += value - - # if 0 < len(terms_row_col): - # terms[poly_row, poly_col] = terms_row_col - - # elif left.shape == right.shape: - terms = {} for poly_row in range(left.shape[0]): -- cgit v1.2.1