summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/additionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
new file mode 100644
index 0000000..776e952
--- /dev/null
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -0,0 +1,70 @@
+
+import abc
+import typing
+import dataclass_abc
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class AddExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def left(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.left.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ assert left.shape == right.shape
+
+ terms = {}
+
+ for underlying in (left, right):
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ if (row, col) in terms:
+ terms_row_col = terms[row, col]
+
+ else:
+ terms_row_col = {}
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ for monomial, value in underlying_terms.items():
+
+ if monomial not in terms_row_col:
+ terms_row_col[monomial] = 0
+
+ terms_row_col[monomial] += value
+
+ terms[row, col] = terms_row_col
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix