summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/diagexprmixin.py46
1 files changed, 32 insertions, 14 deletions
diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py
index 3cf8573..7eded9e 100644
--- a/polymatrix/expression/mixins/diagexprmixin.py
+++ b/polymatrix/expression/mixins/diagexprmixin.py
@@ -19,17 +19,35 @@ class DiagExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
- assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
-
- @dataclass_abc.dataclass_abc(frozen=True)
- class TracePolyMatrix(PolyMatrixMixin):
- underlying: PolyMatrixMixin
- shape: tuple[int, int]
-
- def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
- return self.underlying.get_poly(row, row)
-
- return state, TracePolyMatrix(
- underlying=underlying,
- shape=(underlying.shape[0], 1),
- ) \ No newline at end of file
+ if underlying.shape[1] == 1:
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class DiagPolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ if row == col:
+ return self.underlying.get_poly(row, 0)
+ else:
+ return {tuple(): 0.0}
+
+ return state, DiagPolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], underlying.shape[0]),
+ )
+
+ else:
+ assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class TracePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
+ return self.underlying.get_poly(row, row)
+
+ return state, TracePolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], 1),
+ ) \ No newline at end of file