diff options
Diffstat (limited to 'polymatrix/expression/mixins/divisionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/divisionexprmixin.py | 109 |
1 files changed, 0 insertions, 109 deletions
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py deleted file mode 100644 index 6dd2bd6..0000000 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -import abc -import dataclasses -import typing - -if typing.TYPE_CHECKING: - from polymatrix.expressionstate.abc import ExpressionState - -from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin -from polymatrix.utils.getstacklines import FrameSummary -from polymatrix.utils.tooperatorexception import to_operator_exception -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix - - -class DivisionExprMixin(ExpressionBaseMixin): - @property - @abc.abstractmethod - def left(self) -> ExpressionBaseMixin: ... - - @property - @abc.abstractmethod - def right(self) -> ExpressionBaseMixin: ... - - @property - @abc.abstractmethod - def stack(self) -> tuple[FrameSummary]: ... - - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: - state, left = self.left.apply(state=state) - state, right = self.right.apply(state=state) - - # if left.shape == (1, 1): - # left, right = right, left - - # assert right.shape == (1, 1) - - if not (right.shape == (1, 1)): - raise AssertionError( - to_operator_exception( - message=f"{right.shape=} is not (1, 1)", - stack=self.stack, - ) - ) - - right_poly = right.get_poly(0, 0) - - if len(right_poly) == 1 and tuple() in right_poly: - right_inv = {(0, 0): {tuple(): 1 / right_poly[tuple()]}} - return ElemMultExprMixin.elem_mult( - state=state, - left=left, - right=init_poly_matrix( - data=right_inv, - shape=(1, 1), - ), - ) - - # add an auxillary equation and, therefore, needs to be cached - if self in state.cache: - return state, state.cache[self] - - poly_matrix_data = {} - - division_variable = state.n_param - state = state.register(n_param=1) - - for row in range(left.shape[0]): - for col in range(left.shape[1]): - underlying_poly = left.get_poly(row, col) - if underlying_poly is None: - continue - - def gen_polynomial(): - for monomial, value in underlying_poly.items(): - yield monomial + ((division_variable, 1),), value - - poly_matrix_data[row, col] = dict(gen_polynomial()) - - def gen_auxillary_polynomials(): - for monomial, value in right_poly.items(): - yield monomial + ((division_variable, 1),), value - - auxillary_poly = dict(gen_auxillary_polynomials()) - - if tuple() not in auxillary_poly: - auxillary_poly[tuple()] = 0 - - auxillary_poly[tuple()] -= 1 - - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=left.shape, - ) - - state = dataclasses.replace( - state, - auxillary_equations=state.auxillary_equations - | {division_variable: auxillary_poly}, - cache=state.cache | {self: poly_matrix}, - ) - - return state, poly_matrix |