diff options
author | Nao Pross <np@0hm.ch> | 2024-05-04 18:14:53 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-04 18:14:53 +0200 |
commit | d675f9a66bc0ee2c03360f75890dccfbe72a0fac (patch) | |
tree | 1b05568a7400016b0359310cf41467b6256f891b | |
parent | Fix Expression._binary, avoid creating unnecessary nodes (diff) | |
download | polymatrix-d675f9a66bc0ee2c03360f75890dccfbe72a0fac.tar.gz polymatrix-d675f9a66bc0ee2c03360f75890dccfbe72a0fac.zip |
Fix AdditionExprMixin
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 49 |
1 files changed, 13 insertions, 36 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 02b45a4..17de72c 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.typing import PolyMatrixDict from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix @@ -47,39 +48,15 @@ class AdditionExprMixin(ExpressionBaseMixin): left, right = broadcast_poly_matrix(left, right, self.stack) - poly_matrix_data = {} - - # NP: this code is very old I presume, so it iterates over zero entries - # FIXME: iterate only over non-zero entries - for row in range(left.shape[0]): - for col in range(left.shape[1]): - poly_data = {} - - for underlying in (left, right): - polynomial = underlying.get_poly(row, col) - if polynomial is None: - continue - - if len(poly_data) == 0: - poly_data = dict(polynomial) - - else: - for monomial, value in polynomial.items(): - if monomial not in poly_data: - poly_data[monomial] = value - - else: - poly_data[monomial] += value - - if math.isclose(poly_data[monomial], 0): - del poly_data[monomial] - - if 0 < len(poly_data): - poly_matrix_data[row, col] = poly_data - - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=left.shape, - ) - - return state, poly_matrix + # Keep left as-is and add the stuff from right + result = PolyMatrixDict.empty() + for entry, right_poly in right.entries(): + new_poly = left.at(*entry) + for monomial, coeff in right_poly.terms(): + if monomial in new_poly: + new_poly[monomial] += coeff + else: + new_poly[monomial] = coeff + + result[*entry] = new_poly + return state, init_poly_matrix(result, left.shape) |