summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-12-06 13:13:32 +0100
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-12-06 13:13:32 +0100
commitb4557d119212abcb4e0df452743a0fe91c97c751 (patch)
tree4827fd91934ed08558af1630bb8e1f0fc764c9d4
parentremove polymatrix entry when two entries add up to zero (diff)
downloadpolymatrix-b4557d119212abcb4e0df452743a0fe91c97c751.tar.gz
polymatrix-b4557d119212abcb4e0df452743a0fe91c97c751.zip
bugfix: elementwise multiplication works with a scalar value on both sides
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py41
1 files changed, 3 insertions, 38 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 408b8c8..ab1f265 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -36,6 +36,9 @@ class ElemMultExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
+ if left.shape != right.shape and left.shape == (1, 1):
+ left, right = right, left
+
if right.shape == (1, 1):
right_poly = right.get_poly(0, 0)
@@ -52,44 +55,6 @@ class ElemMultExprMixin(ExpressionBaseMixin):
shape=left.shape,
)
- # if right.shape == (1, 1):
-
- # right_terms = right.get_poly(0, 0)
-
- # terms = {}
-
- # for poly_row in range(left.shape[0]):
- # for poly_col in range(left.shape[1]):
-
- # terms_row_col = {}
-
- # try:
- # left_terms = left.get_poly(poly_row, poly_col)
- # except KeyError:
- # continue
-
- # for (left_monomial, left_value), (right_monomial, right_value) \
- # in itertools.product(left_terms.items(), right_terms.items()):
-
- # value = left_value * right_value
-
- # # if value == 0:
- # # continue
-
- # # monomial = tuple(sorted(left_monomial + right_monomial))
-
- # new_monomial = merge_monomial_indices((left_monomial, right_monomial))
-
- # if new_monomial not in terms_row_col:
- # terms_row_col[new_monomial] = 0
-
- # terms_row_col[new_monomial] += value
-
- # if 0 < len(terms_row_col):
- # terms[poly_row, poly_col] = terms_row_col
-
- # elif left.shape == right.shape:
-
terms = {}
for poly_row in range(left.shape[0]):