From d675f9a66bc0ee2c03360f75890dccfbe72a0fac Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Sat, 4 May 2024 18:14:53 +0200
Subject: Fix AdditionExprMixin

---
 polymatrix/expression/mixins/additionexprmixin.py | 49 ++++++-----------------
 1 file changed, 13 insertions(+), 36 deletions(-)

diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 02b45a4..17de72c 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
 
 from polymatrix.utils.getstacklines import FrameSummary
 from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.typing import PolyMatrixDict
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
@@ -47,39 +48,15 @@ class AdditionExprMixin(ExpressionBaseMixin):
 
         left, right = broadcast_poly_matrix(left, right, self.stack)
 
-        poly_matrix_data = {}
-
-        # NP: this code is very old I presume, so it iterates over zero entries
-        # FIXME: iterate only over non-zero entries
-        for row in range(left.shape[0]):
-            for col in range(left.shape[1]):
-                poly_data = {}
-
-                for underlying in (left, right):
-                    polynomial = underlying.get_poly(row, col)
-                    if polynomial is None:
-                        continue
-
-                    if len(poly_data) == 0:
-                        poly_data = dict(polynomial)
-
-                    else:
-                        for monomial, value in polynomial.items():
-                            if monomial not in poly_data:
-                                poly_data[monomial] = value
-
-                            else:
-                                poly_data[monomial] += value
-
-                                if math.isclose(poly_data[monomial], 0):
-                                    del poly_data[monomial]
-
-                if 0 < len(poly_data):
-                    poly_matrix_data[row, col] = poly_data
-
-        poly_matrix = init_poly_matrix(
-            data=poly_matrix_data,
-            shape=left.shape,
-        )
-
-        return state, poly_matrix
+        # Keep left as-is and add the stuff from right
+        result = PolyMatrixDict.empty()
+        for entry, right_poly in right.entries():
+            new_poly = left.at(*entry)
+            for monomial, coeff in right_poly.terms():
+                if monomial in new_poly:
+                    new_poly[monomial] += coeff
+                else:
+                    new_poly[monomial] = coeff
+
+            result[*entry] = new_poly
+        return state, init_poly_matrix(result, left.shape)
-- 
cgit v1.2.1