summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-04 18:14:53 +0200
committerNao Pross <np@0hm.ch>2024-05-04 18:14:53 +0200
commitd675f9a66bc0ee2c03360f75890dccfbe72a0fac (patch)
tree1b05568a7400016b0359310cf41467b6256f891b
parentFix Expression._binary, avoid creating unnecessary nodes (diff)
downloadpolymatrix-d675f9a66bc0ee2c03360f75890dccfbe72a0fac.tar.gz
polymatrix-d675f9a66bc0ee2c03360f75890dccfbe72a0fac.zip
Fix AdditionExprMixin
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py49
1 files changed, 13 insertions, 36 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 02b45a4..17de72c 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.typing import PolyMatrixDict
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
@@ -47,39 +48,15 @@ class AdditionExprMixin(ExpressionBaseMixin):
left, right = broadcast_poly_matrix(left, right, self.stack)
- poly_matrix_data = {}
-
- # NP: this code is very old I presume, so it iterates over zero entries
- # FIXME: iterate only over non-zero entries
- for row in range(left.shape[0]):
- for col in range(left.shape[1]):
- poly_data = {}
-
- for underlying in (left, right):
- polynomial = underlying.get_poly(row, col)
- if polynomial is None:
- continue
-
- if len(poly_data) == 0:
- poly_data = dict(polynomial)
-
- else:
- for monomial, value in polynomial.items():
- if monomial not in poly_data:
- poly_data[monomial] = value
-
- else:
- poly_data[monomial] += value
-
- if math.isclose(poly_data[monomial], 0):
- del poly_data[monomial]
-
- if 0 < len(poly_data):
- poly_matrix_data[row, col] = poly_data
-
- poly_matrix = init_poly_matrix(
- data=poly_matrix_data,
- shape=left.shape,
- )
-
- return state, poly_matrix
+ # Keep left as-is and add the stuff from right
+ result = PolyMatrixDict.empty()
+ for entry, right_poly in right.entries():
+ new_poly = left.at(*entry)
+ for monomial, coeff in right_poly.terms():
+ if monomial in new_poly:
+ new_poly[monomial] += coeff
+ else:
+ new_poly[monomial] = coeff
+
+ result[*entry] = new_poly
+ return state, init_poly_matrix(result, left.shape)