summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/additionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py49
1 files changed, 23 insertions, 26 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 3d2d15b..c4f3113 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,5 +1,6 @@
import abc
+import collections
import typing
import dataclass_abc
@@ -29,8 +30,6 @@ class AdditionExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- terms = {}
-
if left.shape == (1, 1):
left, right = right, left
@@ -44,15 +43,12 @@ class AdditionExprMixin(ExpressionBaseMixin):
def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
return self.underlying_monomials
- try:
- underlying_terms = right.get_poly(0, 0)
+ polynomial = right.get_poly(0, 0)
- except KeyError:
- pass
+ if polynomial is not None:
- else:
broadcasted_right = BroadCastedPolyMatrix(
- underlying_monomials=underlying_terms,
+ underlying_monomials=polynomial,
shape=left.shape,
)
@@ -63,30 +59,31 @@ class AdditionExprMixin(ExpressionBaseMixin):
all_underlying = (left, right)
- for underlying in all_underlying:
+ terms = {}
- for row in range(left.shape[0]):
- for col in range(left.shape[1]):
-
- if (row, col) in terms:
- terms_row_col = terms[row, col]
+ for row in range(left.shape[0]):
+ for col in range(left.shape[1]):
- else:
- terms_row_col = {}
+ terms_row_col = {}
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
- continue
+ for underlying in all_underlying:
- for monomial, value in underlying_terms.items():
-
- if monomial not in terms_row_col:
- terms_row_col[monomial] = 0
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
+ continue
- terms_row_col[monomial] += value
+ if len(terms_row_col) == 0:
+ terms_row_col = dict(polynomial)
- terms[row, col] = terms_row_col
+ else:
+ for monomial, value in polynomial.items():
+
+ if monomial not in terms_row_col:
+ terms_row_col[monomial] = value
+ else:
+ terms_row_col[monomial] += value
+
+ terms[row, col] = terms_row_col
poly_matrix = init_poly_matrix(
terms=terms,