From b4557d119212abcb4e0df452743a0fe91c97c751 Mon Sep 17 00:00:00 2001
From: Michael Schneeberger <michael.schneeberger@fhnw.ch>
Date: Tue, 6 Dec 2022 13:13:32 +0100
Subject: bugfix: elementwise multiplication works with a scalar value on both
 sides

---
 polymatrix/expression/mixins/elemmultexprmixin.py | 41 ++---------------------
 1 file changed, 3 insertions(+), 38 deletions(-)

diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 408b8c8..ab1f265 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -36,6 +36,9 @@ class ElemMultExprMixin(ExpressionBaseMixin):
         state, left = self.left.apply(state=state)
         state, right = self.right.apply(state=state)
 
+        if left.shape != right.shape and left.shape == (1, 1):
+            left, right = right, left
+
         if right.shape == (1, 1):
             right_poly = right.get_poly(0, 0)
 
@@ -52,44 +55,6 @@ class ElemMultExprMixin(ExpressionBaseMixin):
                 shape=left.shape,
             )
 
-        # if right.shape == (1, 1):
-
-        #     right_terms = right.get_poly(0, 0)
-
-        #     terms = {}
-
-        #     for poly_row in range(left.shape[0]):
-        #         for poly_col in range(left.shape[1]):
-
-        #             terms_row_col = {}
-
-        #             try:
-        #                 left_terms = left.get_poly(poly_row, poly_col)
-        #             except KeyError:
-        #                 continue
-
-        #             for (left_monomial, left_value), (right_monomial, right_value) \
-        #                     in itertools.product(left_terms.items(), right_terms.items()):
-
-        #                 value = left_value * right_value
-
-        #                 # if value == 0:
-        #                 #     continue
-
-        #                 # monomial = tuple(sorted(left_monomial + right_monomial))
-
-        #                 new_monomial = merge_monomial_indices((left_monomial, right_monomial))
-
-        #                 if new_monomial not in terms_row_col:
-        #                     terms_row_col[new_monomial] = 0
-
-        #                 terms_row_col[new_monomial] += value
-
-        #             if 0 < len(terms_row_col):
-        #                 terms[poly_row, poly_col] = terms_row_col
-
-        # elif left.shape == right.shape:
-
         terms = {}
 
         for poly_row in range(left.shape[0]):
-- 
cgit v1.2.1