diff options
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py new file mode 100644 index 0000000..776e952 --- /dev/null +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -0,0 +1,70 @@ + +import abc +import typing +import dataclass_abc + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.polymatrixexprstate import PolyMatrixExprState + + +class AddExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def left(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def right(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + @property + def shape(self) -> tuple[int, int]: + return self.left.shape + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: PolyMatrixExprState, + ) -> tuple[PolyMatrixExprState, PolyMatrix]: + state, left = self.left.apply(state=state) + state, right = self.right.apply(state=state) + + assert left.shape == right.shape + + terms = {} + + for underlying in (left, right): + + for row in range(self.shape[0]): + for col in range(self.shape[1]): + + if (row, col) in terms: + terms_row_col = terms[row, col] + + else: + terms_row_col = {} + + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + for monomial, value in underlying_terms.items(): + + if monomial not in terms_row_col: + terms_row_col[monomial] = 0 + + terms_row_col[monomial] += value + + terms[row, col] = terms_row_col + + poly_matrix = init_poly_matrix( + terms=terms, + shape=self.shape, + ) + + return state, poly_matrix |