From e8341a6367db6e98ed3ebd6a732544fc52c006e8 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 4 May 2024 23:07:04 +0200 Subject: Adapt nested polymatrix classes to use new API --- polymatrix/expression/mixins/blockdiagexprmixin.py | 6 +++- polymatrix/expression/mixins/diagexprmixin.py | 32 ++++++++++++---------- polymatrix/expression/mixins/eyeexprmixin.py | 17 ++++++------ polymatrix/expression/mixins/fromsympyexprmixin.py | 2 +- polymatrix/expression/mixins/getitemexprmixin.py | 6 ++++ polymatrix/expression/mixins/repmatexprmixin.py | 6 ++-- polymatrix/expression/mixins/reshapeexprmixin.py | 5 ++-- .../expression/mixins/setelementatexprmixin.py | 8 ++++-- polymatrix/expression/mixins/symmetricexprmixin.py | 1 + polymatrix/expression/mixins/transposeexprmixin.py | 7 +++-- polymatrix/expression/mixins/vstackexprmixin.py | 7 ++--- 11 files changed, 58 insertions(+), 39 deletions(-) diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index 77edf2b..420372f 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -10,6 +10,7 @@ if typing.TYPE_CHECKING: from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyDict from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -34,13 +35,16 @@ class BlockDiagExprMixin(ExpressionBaseMixin): state, polymat = expr.apply(state=state) all_underlying.append(polymat) - # NP: this is a very weird place to put a class + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class BlockDiagPolyMatrix(PolyMatrixMixin): all_underlying: tuple[PolyMatrixMixin] underlying_row_col_range: tuple[tuple[int, int], ...] shape: tuple[int, int] + def at(self, row: int, col: int) -> PolyDict: + return self.get_poly(row, col) or PolyDict.empty() + # FIXME: typing problems def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: for polymatrix, ((row_start, col_start), (row_end, col_end)) in zip( diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py index 66d5a6c..9d019e4 100644 --- a/polymatrix/expression/mixins/diagexprmixin.py +++ b/polymatrix/expression/mixins/diagexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.typing import PolyDict class DiagExprMixin(ExpressionBaseMixin): @@ -32,40 +33,41 @@ class DiagExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) + # Vector to diagonal matrix if underlying.shape[1] == 1: + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) - class DiagPolyMatrix(PolyMatrixMixin): + class DiagFromVecPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - if row == col: - return self.underlying.get_poly(row, 0) - else: - # FIXME: should return none according to base class - # NP: Though returning zero makes more sense - return {tuple(): 0.0} + def at(self, row: int, col: int) -> PolyDict: + if row != col: + return PolyDict.empty() - return state, DiagPolyMatrix( + return self.underlying.get_poly(row, 0) + + return state, DiagFromVecPolyMatrix( underlying=underlying, shape=(underlying.shape[0], underlying.shape[0]), ) + # Diagonal matrix to vector else: - # NP: replace assertions with meaningful exception + # FIXME: replace assertions with meaningful error message assert underlying.shape[0] == underlying.shape[1], f"{underlying.shape=}" - # NP: why is this called Trace? + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) - class TracePolyMatrix(PolyMatrixMixin): + class VecFromDiagPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin shape: tuple[int, int] - def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]: - return self.underlying.get_poly(row, row) + def at(self, row: int, _col: int) -> PolyDict: + return self.underlying.at(row, row) - return state, TracePolyMatrix( + return state, VecFromDiagPolyMatrix( underlying=underlying, shape=(underlying.shape[0], 1), ) diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index b18c89b..8b61018 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -10,6 +10,7 @@ if typing.TYPE_CHECKING: from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyDict, MonomialIndex class EyeExprMixin(ExpressionBaseMixin): @@ -24,20 +25,20 @@ class EyeExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionState, PolyMatrix]: state, variable = self.variable.apply(state) + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class EyePolyMatrix(PolyMatrixMixin): shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - if max(row, col) <= self.shape[0]: - if row == col: - return {tuple(): 1.0} + def at(self, row: int, col: int) -> PolyDict: + size, _ = self.shape + if max(row, col) > size: + raise IndexError(f"Identity matrix has size {size}, {row, col} is out of bounds.") - else: - return None + if row != col: + return PolyDict.empty() - else: - raise Exception(f"{(row, col)=} is out of bounds") + return PolyDict({MonomialIndex.constant(): 1.}) # FIXME: this behaviour is counterintuitive, eye should take just a number for the dimension n_row = variable.shape[0] diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index 784ba00..3a896e1 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -75,7 +75,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): raise ValueError(f"Cannot convert sympy expression {entry} " "into a polynomial, are you sure it is a polynomial?") from e - # Convert sympy variables to our variables, i.e VariableMixin + # Convert sympy variables to our variables sympy_to_var = { sympy_idx: init_variable(var.name, shape=(1,1)) for sympy_idx, var in enumerate(sympy_poly.gens) diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index f8fece4..81cb439 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -11,6 +11,7 @@ if typing.TYPE_CHECKING: from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyDict class GetItemExprMixin(ExpressionBaseMixin): @@ -65,6 +66,7 @@ class GetItemExprMixin(ExpressionBaseMixin): get_proper_index(self.index[1], underlying.shape[1]), ) + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class GetItemPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin @@ -74,6 +76,10 @@ class GetItemExprMixin(ExpressionBaseMixin): def shape(self) -> tuple[int, int]: return (len(self.index[0]), len(self.index[1])) + def at(self, row: int, col: int) -> PolyDict: + # FIXME: this is a quick fix + return self.get_poly(self, row, col) or PolyDict.empty() + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: try: n_row = self.index[0][row] diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index 324a298..29e0553 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.typing import PolyDict class RepMatExprMixin(ExpressionBaseMixin): @@ -27,18 +28,19 @@ class RepMatExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class RepMatPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + def at(self, row: int, col: int) -> PolyDict: n_row, n_col = underlying.shape rel_row = row % n_row rel_col = col % n_col - return self.underlying.get_poly(rel_row, rel_col) + return self.underlying.at(rel_row, rel_col) return state, RepMatPolyMatrix( underlying=underlying, diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 84c2581..4c1140e 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -12,6 +12,7 @@ if typing.TYPE_CHECKING: from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.typing import PolyDict class ReshapeExprMixin(ExpressionBaseMixin): @@ -38,13 +39,13 @@ class ReshapeExprMixin(ExpressionBaseMixin): shape: tuple[int, int] underlying_shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + def at(self, row: int, col: int) -> PolyDict: index = row + self.shape[0] * col underlying_col = int(index / self.underlying_shape[0]) underlying_row = index - underlying_col * self.underlying_shape[0] - return self.underlying.get_poly(underlying_row, underlying_col) + return self.underlying.at(underlying_row, underlying_col) # replace expression by their number of rows def acc_new_shape(acc, index): diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 9757a28..35568b4 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -12,6 +12,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyDict class SetElementAtExprMixin(ExpressionBaseMixin): @@ -47,18 +48,19 @@ class SetElementAtExprMixin(ExpressionBaseMixin): if polynomial is None: polynomial = 0 + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class SetElementAtPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin shape: tuple[int, int] index: tuple[int, int] - polynomial: dict + polynomial: PolyDict - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + def at(self, row: int, col: int) -> PolyDict: if (row, col) == self.index: return self.polynomial else: - return self.underlying.get_poly(row, col) + return self.underlying.at(row, col) return state, SetElementAtPolyMatrix( underlying=underlying, diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index f585240..f608139 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -34,6 +34,7 @@ class SymmetricExprMixin(ExpressionBaseMixin): assert underlying.shape[0] == underlying.shape[1] + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class SymmetricPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index fbd5f30..0e076d3 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.polymatrix.typing import PolyDict from polymatrix.polymatrix.abc import PolyMatrix @@ -30,13 +31,15 @@ class TransposeExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class TransposePolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - return self.underlying.get_poly(col, row) + def at(self, row: int, col: int) -> PolyDict: + return self.underlying.at(col, row) + return state, TransposePolyMatrix( underlying=underlying, diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index 2091f88..dc9339a 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -40,6 +40,7 @@ class VStackExprMixin(ExpressionBaseMixin): underlying.shape[1] == all_underlying[0].shape[1] ), f"{underlying.shape[1]} not equal {all_underlying[0].shape[1]}" + # FIXME: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class VStackPolyMatrix(PolyMatrixMixin): all_underlying: tuple[PolyMatrixMixin] @@ -47,15 +48,11 @@ class VStackExprMixin(ExpressionBaseMixin): shape: tuple[int, int] def at(self, row: int, col: int) -> PolyDict: - # FIXME: this is a quick workaround - return self.get_poly(row, col) or PolyDict.empty() - - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: for polymatrix, (row_start, row_end) in zip( self.all_underlying, self.underlying_row_range ): if row_start <= row < row_end: - return polymatrix.get_poly( + return polymatrix.at( row=row - row_start, col=col, ) -- cgit v1.2.1