diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-12-06 13:13:32 +0100 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-12-06 13:13:32 +0100 |
commit | b4557d119212abcb4e0df452743a0fe91c97c751 (patch) | |
tree | 4827fd91934ed08558af1630bb8e1f0fc764c9d4 | |
parent | remove polymatrix entry when two entries add up to zero (diff) | |
download | polymatrix-b4557d119212abcb4e0df452743a0fe91c97c751.tar.gz polymatrix-b4557d119212abcb4e0df452743a0fe91c97c751.zip |
bugfix: elementwise multiplication works with a scalar value on both sides
-rw-r--r-- | polymatrix/expression/mixins/elemmultexprmixin.py | 41 |
1 files changed, 3 insertions, 38 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index 408b8c8..ab1f265 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -36,6 +36,9 @@ class ElemMultExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) + if left.shape != right.shape and left.shape == (1, 1): + left, right = right, left + if right.shape == (1, 1): right_poly = right.get_poly(0, 0) @@ -52,44 +55,6 @@ class ElemMultExprMixin(ExpressionBaseMixin): shape=left.shape, ) - # if right.shape == (1, 1): - - # right_terms = right.get_poly(0, 0) - - # terms = {} - - # for poly_row in range(left.shape[0]): - # for poly_col in range(left.shape[1]): - - # terms_row_col = {} - - # try: - # left_terms = left.get_poly(poly_row, poly_col) - # except KeyError: - # continue - - # for (left_monomial, left_value), (right_monomial, right_value) \ - # in itertools.product(left_terms.items(), right_terms.items()): - - # value = left_value * right_value - - # # if value == 0: - # # continue - - # # monomial = tuple(sorted(left_monomial + right_monomial)) - - # new_monomial = merge_monomial_indices((left_monomial, right_monomial)) - - # if new_monomial not in terms_row_col: - # terms_row_col[new_monomial] = 0 - - # terms_row_col[new_monomial] += value - - # if 0 < len(terms_row_col): - # terms[poly_row, poly_col] = terms_row_col - - # elif left.shape == right.shape: - terms = {} for poly_row in range(left.shape[0]): |