diff options
author | Nao Pross <np@0hm.ch> | 2024-05-04 19:04:03 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-04 19:07:11 +0200 |
commit | 3f038881247094b8a13a8b6417007703655896dd (patch) | |
tree | a142d5c986c9ca546e680c674f228f064326ce50 | |
parent | Reintroduce PolyMatrixDict.__getitem__ to make the typechecker shut up (diff) | |
download | polymatrix-3f038881247094b8a13a8b6417007703655896dd.tar.gz polymatrix-3f038881247094b8a13a8b6417007703655896dd.zip |
Fix v_stack, add __str__ representation to some operation mixins
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/__init__.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/expression.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 29 | ||||
-rw-r--r-- | polymatrix/expression/mixins/eyeexprmixin.py | 1 | ||||
-rw-r--r-- | polymatrix/expression/mixins/symmetricexprmixin.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/mixins/vstackexprmixin.py | 5 |
6 files changed, 45 insertions, 7 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 6f7eae2..b150cc5 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -15,7 +15,7 @@ def v_stack( def gen_underlying(): for expr in expressions: if isinstance(expr, Expression): - yield expr + yield expr.underlying else: yield polymatrix.expression.from_.from_(expr) diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 3c52297..fbe7ac3 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -493,12 +493,12 @@ class Expression(ExpressionBaseMixin, ABC): ) -@dataclassabc(frozen=True, repr=False) +@dataclassabc(frozen=True) class ExpressionImpl(Expression): underlying: ExpressionBaseMixin - def __repr__(self) -> str: - return self.underlying.__repr__() + def __str__(self): + return f"Expression({self.underlying})" def copy(self, underlying: ExpressionBaseMixin) -> Expression: return dataclasses.replace( @@ -539,8 +539,8 @@ class VariableExpression(Expression, Variable): class VariableExpressionImpl(VariableExpression): underlying: ExpressionBaseMixin - def __repr__(self) -> str: - return self.underlying.__repr__() + def __str__(self): + return f"VariableExpression({self.underlying})" def copy(self, underlying: ExpressionBaseMixin) -> Expression: return init_expression(underlying) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 119d8c7..600ad81 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -83,6 +83,9 @@ class AdditionExprImpl(AdditionExprMixin): def __repr__(self): return f"{self.__class__.__name__}(left={self.left}, right={self.right})" + def __str__(self): + return f"({self.left} + {self.right})" + @dataclassabc.dataclassabc(frozen=True) class BlockDiagExprImpl(BlockDiagExprMixin): @@ -128,6 +131,9 @@ class ElemMultExprImpl(ElemMultExprMixin): left: ExpressionBaseMixin right: ExpressionBaseMixin + def __str__(self): + return f"({self.left} * {self.right})" + @dataclassabc.dataclassabc(frozen=True) class EvalExprImpl(EvalExprMixin): @@ -151,6 +157,11 @@ class FilterExprImpl(FilterExprMixin): class FromNumbersExprImpl(FromNumbersExprMixin): data: tuple[tuple[int | float]] + def __str__(self): + if len(self.data) == 1: + if len(self.data[0]) == 1: + return str(self.data[0][0]) + @dataclassabc.dataclassabc(frozen=True) class FromNumpyExprImpl(FromNumpyExprMixin): @@ -237,9 +248,13 @@ class MatrixMultExprImpl(MatrixMultExprMixin): stack: tuple[FrameSummary] # implement custom __repr__ method that returns a representation without the stack + # FIXME: remove this? def __repr__(self): return f"{self.__class__.__name__}(left={self.left}, right={self.right})" + def __str__(self): + return f"({self.left} @ {self.right})" + @dataclassabc.dataclassabc(frozen=True) class DegreeExprImpl(DegreeExprMixin): @@ -367,6 +382,9 @@ class ToSymmetricMatrixExprImpl(ToSymmetricMatrixExprMixin): class TransposeExprImpl(TransposeExprMixin): underlying: ExpressionBaseMixin + def __str__(self): + return f"({self.underlying}).T" + @dataclassabc.dataclassabc(frozen=True) class TruncateExprImpl(TruncateExprMixin): @@ -381,7 +399,16 @@ class VariableImpl(VariableMixin): name: str shape: tuple[int, int] + def __str__(self): + return self.name + @dataclassabc.dataclassabc(frozen=True) class VStackExprImpl(VStackExprMixin): - underlying: tuple + underlying: tuple[ExpressionBaseMixin, ...] + + def __str__(self): + inner = ", ".join(map(str, self.underlying)) + return f"v_stack({inner})" + + diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index e3b28c2..b18c89b 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -39,6 +39,7 @@ class EyeExprMixin(ExpressionBaseMixin): else: raise Exception(f"{(row, col)=} is out of bounds") + # FIXME: this behaviour is counterintuitive, eye should take just a number for the dimension n_row = variable.shape[0] polymatrix = EyePolyMatrix( diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index 8bc398f..f585240 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.expressionstate.abc import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyDict from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -41,6 +42,10 @@ class SymmetricExprMixin(ExpressionBaseMixin): def shape(self) -> tuple[int, int]: return self.underlying.shape + def at(self, row: int, col: int) -> PolyDict: + # FIXME: this is a quick workaround + return PolyDict(self.get_poly(row, col) or {}) + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: def gen_symmetric_monomials(): for i_row, i_col in ((row, col), (col, row)): diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index ab23ee6..2091f88 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -8,6 +8,7 @@ if typing.TYPE_CHECKING: from polymatrix.expressionstate.abc import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.typing import PolyDict from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix @@ -45,6 +46,10 @@ class VStackExprMixin(ExpressionBaseMixin): underlying_row_range: tuple[tuple[int, int], ...] shape: tuple[int, int] + def at(self, row: int, col: int) -> PolyDict: + # FIXME: this is a quick workaround + return self.get_poly(row, col) or PolyDict.empty() + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: for polymatrix, (row_start, row_end) in zip( self.all_underlying, self.underlying_row_range |