summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-03-16 16:41:39 +0100
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-03-16 16:41:39 +0100
commitc475212fd6ed0c209b74e3967b0d43ed4ac083fa (patch)
tree7c77b9fb3857b317a7fdfa21861acf8189cf69ad
parentmake sorting optional for to_matrix_repr (diff)
downloadpolymatrix-c475212fd6ed0c209b74e3967b0d43ed4ac083fa.tar.gz
polymatrix-c475212fd6ed0c209b74e3967b0d43ed4ac083fa.zip
diag operation on a vector produces a sqaure matrix
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/mixins/diagexprmixin.py46
1 files changed, 32 insertions, 14 deletions
diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py
index 3cf8573..7eded9e 100644
--- a/polymatrix/expression/mixins/diagexprmixin.py
+++ b/polymatrix/expression/mixins/diagexprmixin.py
@@ -19,17 +19,35 @@ class DiagExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
- assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
-
- @dataclass_abc.dataclass_abc(frozen=True)
- class TracePolyMatrix(PolyMatrixMixin):
- underlying: PolyMatrixMixin
- shape: tuple[int, int]
-
- def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
- return self.underlying.get_poly(row, row)
-
- return state, TracePolyMatrix(
- underlying=underlying,
- shape=(underlying.shape[0], 1),
- ) \ No newline at end of file
+ if underlying.shape[1] == 1:
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class DiagPolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ if row == col:
+ return self.underlying.get_poly(row, 0)
+ else:
+ return {tuple(): 0.0}
+
+ return state, DiagPolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], underlying.shape[0]),
+ )
+
+ else:
+ assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class TracePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
+ return self.underlying.get_poly(row, row)
+
+ return state, TracePolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], 1),
+ ) \ No newline at end of file