diff options
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 49 |
1 files changed, 23 insertions, 26 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 3d2d15b..c4f3113 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,5 +1,6 @@ import abc +import collections import typing import dataclass_abc @@ -29,8 +30,6 @@ class AdditionExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - terms = {} - if left.shape == (1, 1): left, right = right, left @@ -44,15 +43,12 @@ class AdditionExprMixin(ExpressionBaseMixin): def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: return self.underlying_monomials - try: - underlying_terms = right.get_poly(0, 0) + polynomial = right.get_poly(0, 0) - except KeyError: - pass + if polynomial is not None: - else: broadcasted_right = BroadCastedPolyMatrix( - underlying_monomials=underlying_terms, + underlying_monomials=polynomial, shape=left.shape, ) @@ -63,30 +59,31 @@ class AdditionExprMixin(ExpressionBaseMixin): all_underlying = (left, right) - for underlying in all_underlying: + terms = {} - for row in range(left.shape[0]): - for col in range(left.shape[1]): - - if (row, col) in terms: - terms_row_col = terms[row, col] + for row in range(left.shape[0]): + for col in range(left.shape[1]): - else: - terms_row_col = {} + terms_row_col = {} - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: - continue + for underlying in all_underlying: - for monomial, value in underlying_terms.items(): - - if monomial not in terms_row_col: - terms_row_col[monomial] = 0 + polynomial = underlying.get_poly(row, col) + if polynomial is None: + continue - terms_row_col[monomial] += value + if len(terms_row_col) == 0: + terms_row_col = dict(polynomial) - terms[row, col] = terms_row_col + else: + for monomial, value in polynomial.items(): + + if monomial not in terms_row_col: + terms_row_col[monomial] = value + else: + terms_row_col[monomial] += value + + terms[row, col] = terms_row_col poly_matrix = init_poly_matrix( terms=terms, |