From 3f038881247094b8a13a8b6417007703655896dd Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 4 May 2024 19:04:03 +0200 Subject: Fix v_stack, add __str__ representation to some operation mixins --- polymatrix/expression/__init__.py | 2 +- polymatrix/expression/expression.py | 10 ++++---- polymatrix/expression/impl.py | 29 +++++++++++++++++++++- polymatrix/expression/mixins/eyeexprmixin.py | 1 + polymatrix/expression/mixins/symmetricexprmixin.py | 5 ++++ polymatrix/expression/mixins/vstackexprmixin.py | 5 ++++ 6 files changed, 45 insertions(+), 7 deletions(-) diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 6f7eae2..b150cc5 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -15,7 +15,7 @@ def v_stack( def gen_underlying(): for expr in expressions: if isinstance(expr, Expression): - yield expr + yield expr.underlying else: yield polymatrix.expression.from_.from_(expr) diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 3c52297..fbe7ac3 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -493,12 +493,12 @@ class Expression(ExpressionBaseMixin, ABC): ) -@dataclassabc(frozen=True, repr=False) +@dataclassabc(frozen=True) class ExpressionImpl(Expression): underlying: ExpressionBaseMixin - def __repr__(self) -> str: - return self.underlying.__repr__() + def __str__(self): + return f"Expression({self.underlying})" def copy(self, underlying: ExpressionBaseMixin) -> Expression: return dataclasses.replace( @@ -539,8 +539,8 @@ class VariableExpression(Expression, Variable): class VariableExpressionImpl(VariableExpression): underlying: ExpressionBaseMixin - def __repr__(self) -> str: - return self.underlying.__repr__() + def __str__(self): + return f"VariableExpression({self.underlying})" def copy(self, underlying: ExpressionBaseMixin) -> Expression: return init_expression(underlying) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 119d8c7..600ad81 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -83,6 +83,9 @@ class AdditionExprImpl(AdditionExprMixin): def __repr__(self): return f"{self.__class__.__name__}(left={self.left}, right={self.right})" + def __str__(self): + return f"({self.left} + {self.right})" + @dataclassabc.dataclassabc(frozen=True) class BlockDiagExprImpl(BlockDiagExprMixin): @@ -128,6 +131,9 @@ class ElemMultExprImpl(ElemMultExprMixin): left: ExpressionBaseMixin right: ExpressionBaseMixin + def __str__(self): + return f"({self.left} * {self.right})" + @dataclassabc.dataclassabc(frozen=True) class EvalExprImpl(EvalExprMixin): @@ -151,6 +157,11 @@ class FilterExprImpl(FilterExprMixin): class FromNumbersExprImpl(FromNumbersExprMixin): data: tuple[tuple[int | float]] + def __str__(self): + if len(self.data) == 1: + if len(self.data[0]) == 1: + return str(self.data[0][0]) + @dataclassabc.dataclassabc(frozen=True) class FromNumpyExprImpl(FromNumpyExprMixin): @@ -237,9 +248,13 @@ class MatrixMultExprImpl(MatrixMultExprMixin): stack: tuple[FrameSummary] # implement custom __repr__ method that returns a representation without the stack + # FIXME: remove this? def __repr__(self): return f"{self.__class__.__name__}(left={self.left}, right={self.right})" + def __str__(self): + return f"({self.left} @ {self.right})" + @dataclassabc.dataclassabc(frozen=True) class DegreeExprImpl(DegreeExprMixin): @@ -367,6 +382,9 @@ class ToSymmetricMatrixExprImpl(ToSymmetricMatrixExprMixin): class TransposeExprImpl(TransposeExprMixin): underlying: ExpressionBaseMixin + def __str__(self): + return f"({self.underlying}).T" + @dataclassabc.dataclassabc(frozen=True) class TruncateExprImpl(TruncateExprMixin): @@ -381,7 +399,16 @@ class VariableImpl(VariableMixin): name: str shape: tuple[int, int] + def __str__(self): + return self.name + @dataclassabc.dataclassabc(frozen=True) class VStackExprImpl(VStackExprMixin): - underlying: tuple + underlying: tuple[ExpressionBaseMixin, ...] + + def __str__(self): + inner = ", ".join(map(str, self.underlying)) + return f"v_stack({inner})" + + diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index e3b28c2..b18c89b 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -39,6 +39,7 @@ class EyeExprMixin(ExpressionBaseMixin): else: raise Exception(f"{(row, col)=} is out of bounds") + # FIXME: this behaviour is counterintuitive, eye should take just a number for the dimension n_row = variable.shape[0] polymatrix = EyePolyMatrix( diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index 8bc398f..f585240 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: from polymatrix.expressionstate.abc import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyDict from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -41,6 +42,10 @@ class SymmetricExprMixin(ExpressionBaseMixin): def shape(self) -> tuple[int, int]: return self.underlying.shape + def at(self, row: int, col: int) -> PolyDict: + # FIXME: this is a quick workaround + return PolyDict(self.get_poly(row, col) or {}) + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: def gen_symmetric_monomials(): for i_row, i_col in ((row, col), (col, row)): diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index ab23ee6..2091f88 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -8,6 +8,7 @@ if typing.TYPE_CHECKING: from polymatrix.expressionstate.abc import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.typing import PolyDict from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix @@ -45,6 +46,10 @@ class VStackExprMixin(ExpressionBaseMixin): underlying_row_range: tuple[tuple[int, int], ...] shape: tuple[int, int] + def at(self, row: int, col: int) -> PolyDict: + # FIXME: this is a quick workaround + return self.get_poly(row, col) or PolyDict.empty() + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: for polymatrix, (row_start, row_end) in zip( self.all_underlying, self.underlying_row_range -- cgit v1.2.1