summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/mixins/setelementatexprmixin.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py
index 9d552bc..2bb0bfe 100644
--- a/polymatrix/expression/mixins/setelementatexprmixin.py
+++ b/polymatrix/expression/mixins/setelementatexprmixin.py
@@ -39,9 +39,14 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, value = self.value.apply(state=state)
+ state, value_expr = self.value.apply(state=state)
- assert value.shape == (1, 1)
+ assert value_expr.shape == (1, 1)
+
+ try:
+ value = value_expr.get_poly(0, 0)
+ except KeyError:
+ value = 0
@dataclass_abc.dataclass_abc(frozen=True)
class SetElementAtPolyMatrix(PolyMatrixMixin):
@@ -49,10 +54,6 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
shape: tuple[int, int]
index: tuple[int, int]
value: dict
-
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.shape
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
if (row, col) == self.index:
@@ -64,6 +65,6 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
underlying=underlying,
index=self.index,
shape=underlying.shape,
- value=value.get_poly(0, 0)
+ value=value,
)
\ No newline at end of file