diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-03-16 16:41:39 +0100 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-03-16 16:41:39 +0100 |
commit | c475212fd6ed0c209b74e3967b0d43ed4ac083fa (patch) | |
tree | 7c77b9fb3857b317a7fdfa21861acf8189cf69ad | |
parent | make sorting optional for to_matrix_repr (diff) | |
download | polymatrix-c475212fd6ed0c209b74e3967b0d43ed4ac083fa.tar.gz polymatrix-c475212fd6ed0c209b74e3967b0d43ed4ac083fa.zip |
diag operation on a vector produces a sqaure matrix
-rw-r--r-- | polymatrix/expression/mixins/diagexprmixin.py | 46 |
1 files changed, 32 insertions, 14 deletions
diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py index 3cf8573..7eded9e 100644 --- a/polymatrix/expression/mixins/diagexprmixin.py +++ b/polymatrix/expression/mixins/diagexprmixin.py @@ -19,17 +19,35 @@ class DiagExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state) - assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}' - - @dataclass_abc.dataclass_abc(frozen=True) - class TracePolyMatrix(PolyMatrixMixin): - underlying: PolyMatrixMixin - shape: tuple[int, int] - - def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]: - return self.underlying.get_poly(row, row) - - return state, TracePolyMatrix( - underlying=underlying, - shape=(underlying.shape[0], 1), - )
\ No newline at end of file + if underlying.shape[1] == 1: + @dataclass_abc.dataclass_abc(frozen=True) + class DiagPolyMatrix(PolyMatrixMixin): + underlying: PolyMatrixMixin + shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + if row == col: + return self.underlying.get_poly(row, 0) + else: + return {tuple(): 0.0} + + return state, DiagPolyMatrix( + underlying=underlying, + shape=(underlying.shape[0], underlying.shape[0]), + ) + + else: + assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}' + + @dataclass_abc.dataclass_abc(frozen=True) + class TracePolyMatrix(PolyMatrixMixin): + underlying: PolyMatrixMixin + shape: tuple[int, int] + + def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]: + return self.underlying.get_poly(row, row) + + return state, TracePolyMatrix( + underlying=underlying, + shape=(underlying.shape[0], 1), + )
\ No newline at end of file |