diff options
-rw-r--r-- | polymatrix/expression/mixins/setelementatexprmixin.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 9d552bc..2bb0bfe 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -39,9 +39,14 @@ class SetElementAtExprMixin(ExpressionBaseMixin): state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, value = self.value.apply(state=state) + state, value_expr = self.value.apply(state=state) - assert value.shape == (1, 1) + assert value_expr.shape == (1, 1) + + try: + value = value_expr.get_poly(0, 0) + except KeyError: + value = 0 @dataclass_abc.dataclass_abc(frozen=True) class SetElementAtPolyMatrix(PolyMatrixMixin): @@ -49,10 +54,6 @@ class SetElementAtExprMixin(ExpressionBaseMixin): shape: tuple[int, int] index: tuple[int, int] value: dict - - # @property - # def shape(self) -> tuple[int, int]: - # return self.shape def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: if (row, col) == self.index: @@ -64,6 +65,6 @@ class SetElementAtExprMixin(ExpressionBaseMixin): underlying=underlying, index=self.index, shape=underlying.shape, - value=value.get_poly(0, 0) + value=value, )
\ No newline at end of file |