summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py49
1 files changed, 13 insertions, 36 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 02b45a4..17de72c 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.typing import PolyMatrixDict
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
@@ -47,39 +48,15 @@ class AdditionExprMixin(ExpressionBaseMixin):
left, right = broadcast_poly_matrix(left, right, self.stack)
- poly_matrix_data = {}
-
- # NP: this code is very old I presume, so it iterates over zero entries
- # FIXME: iterate only over non-zero entries
- for row in range(left.shape[0]):
- for col in range(left.shape[1]):
- poly_data = {}
-
- for underlying in (left, right):
- polynomial = underlying.get_poly(row, col)
- if polynomial is None:
- continue
-
- if len(poly_data) == 0:
- poly_data = dict(polynomial)
-
- else:
- for monomial, value in polynomial.items():
- if monomial not in poly_data:
- poly_data[monomial] = value
-
- else:
- poly_data[monomial] += value
-
- if math.isclose(poly_data[monomial], 0):
- del poly_data[monomial]
-
- if 0 < len(poly_data):
- poly_matrix_data[row, col] = poly_data
-
- poly_matrix = init_poly_matrix(
- data=poly_matrix_data,
- shape=left.shape,
- )
-
- return state, poly_matrix
+ # Keep left as-is and add the stuff from right
+ result = PolyMatrixDict.empty()
+ for entry, right_poly in right.entries():
+ new_poly = left.at(*entry)
+ for monomial, coeff in right_poly.terms():
+ if monomial in new_poly:
+ new_poly[monomial] += coeff
+ else:
+ new_poly[monomial] = coeff
+
+ result[*entry] = new_poly
+ return state, init_poly_matrix(result, left.shape)