summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/divisionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/divisionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
new file mode 100644
index 0000000..01d3505
--- /dev/null
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -0,0 +1,81 @@
+
+import abc
+import dataclasses
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class DivisionExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def left(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.left.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ if self in state.cached_polymatrix:
+ return state, state.cached_polymatrix[self]
+
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ assert right.shape == (1, 1)
+
+ terms = {}
+
+ division_variable = state.n_param
+ state = state.register(n_param=1)
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = left.get_poly(row, col)
+ except KeyError:
+ continue
+
+ def gen_monomial_terms():
+ for monomial, value in underlying_terms.items():
+ yield monomial + (division_variable,), value
+
+ terms[row, col] = dict(gen_monomial_terms())
+
+ def gen_auxillary_terms():
+ for monomial, value in right.get_poly(0, 0).items():
+ yield monomial + (division_variable,), value
+
+ auxillary_terms = dict(gen_auxillary_terms())
+
+ if tuple() not in auxillary_terms:
+ auxillary_terms[tuple()] = 0
+
+ auxillary_terms[tuple()] -= 1
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ state = dataclasses.replace(
+ state,
+ auxillary_terms=state.auxillary_terms + (auxillary_terms,),
+ cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ )
+
+ return state, poly_matrix