From d675f9a66bc0ee2c03360f75890dccfbe72a0fac Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 4 May 2024 18:14:53 +0200 Subject: Fix AdditionExprMixin --- polymatrix/expression/mixins/additionexprmixin.py | 49 ++++++----------------- 1 file changed, 13 insertions(+), 36 deletions(-) diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 02b45a4..17de72c 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.typing import PolyMatrixDict from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix @@ -47,39 +48,15 @@ class AdditionExprMixin(ExpressionBaseMixin): left, right = broadcast_poly_matrix(left, right, self.stack) - poly_matrix_data = {} - - # NP: this code is very old I presume, so it iterates over zero entries - # FIXME: iterate only over non-zero entries - for row in range(left.shape[0]): - for col in range(left.shape[1]): - poly_data = {} - - for underlying in (left, right): - polynomial = underlying.get_poly(row, col) - if polynomial is None: - continue - - if len(poly_data) == 0: - poly_data = dict(polynomial) - - else: - for monomial, value in polynomial.items(): - if monomial not in poly_data: - poly_data[monomial] = value - - else: - poly_data[monomial] += value - - if math.isclose(poly_data[monomial], 0): - del poly_data[monomial] - - if 0 < len(poly_data): - poly_matrix_data[row, col] = poly_data - - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=left.shape, - ) - - return state, poly_matrix + # Keep left as-is and add the stuff from right + result = PolyMatrixDict.empty() + for entry, right_poly in right.entries(): + new_poly = left.at(*entry) + for monomial, coeff in right_poly.terms(): + if monomial in new_poly: + new_poly[monomial] += coeff + else: + new_poly[monomial] = coeff + + result[*entry] = new_poly + return state, init_poly_matrix(result, left.shape) -- cgit v1.2.1