diff options
Diffstat (limited to 'polymatrix/expression/mixins/divisionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/divisionexprmixin.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py new file mode 100644 index 0000000..01d3505 --- /dev/null +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -0,0 +1,81 @@ + +import abc +import dataclasses + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.polymatrixexprstate import PolyMatrixExprState + + +class DivisionExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def left(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def right(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + @property + def shape(self) -> tuple[int, int]: + return self.left.shape + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: PolyMatrixExprState, + ) -> tuple[PolyMatrixExprState, PolyMatrix]: + if self in state.cached_polymatrix: + return state, state.cached_polymatrix[self] + + state, left = self.left.apply(state=state) + state, right = self.right.apply(state=state) + + assert right.shape == (1, 1) + + terms = {} + + division_variable = state.n_param + state = state.register(n_param=1) + + for row in range(self.shape[0]): + for col in range(self.shape[1]): + + try: + underlying_terms = left.get_poly(row, col) + except KeyError: + continue + + def gen_monomial_terms(): + for monomial, value in underlying_terms.items(): + yield monomial + (division_variable,), value + + terms[row, col] = dict(gen_monomial_terms()) + + def gen_auxillary_terms(): + for monomial, value in right.get_poly(0, 0).items(): + yield monomial + (division_variable,), value + + auxillary_terms = dict(gen_auxillary_terms()) + + if tuple() not in auxillary_terms: + auxillary_terms[tuple()] = 0 + + auxillary_terms[tuple()] -= 1 + + poly_matrix = init_poly_matrix( + terms=terms, + shape=self.shape, + ) + + state = dataclasses.replace( + state, + auxillary_terms=state.auxillary_terms + (auxillary_terms,), + cached_polymatrix=state.cached_polymatrix | {self: poly_matrix}, + ) + + return state, poly_matrix |