summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/divisionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/divisionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index 01d3505..de5f685 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -19,16 +19,17 @@ class DivisionExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.left.shape
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
+
if self in state.cached_polymatrix:
return state, state.cached_polymatrix[self]
@@ -42,8 +43,8 @@ class DivisionExprMixin(ExpressionBaseMixin):
division_variable = state.n_param
state = state.register(n_param=1)
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(left.shape[0]):
+ for col in range(left.shape[1]):
try:
underlying_terms = left.get_poly(row, col)
@@ -69,12 +70,12 @@ class DivisionExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=left.shape,
)
state = dataclasses.replace(
state,
- auxillary_terms=state.auxillary_terms + (auxillary_terms,),
+ auxillary_equations=state.auxillary_equations | {division_variable: auxillary_terms},
cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
)