summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-04 19:04:03 +0200
committerNao Pross <np@0hm.ch>2024-05-04 19:07:11 +0200
commit3f038881247094b8a13a8b6417007703655896dd (patch)
treea142d5c986c9ca546e680c674f228f064326ce50
parentReintroduce PolyMatrixDict.__getitem__ to make the typechecker shut up (diff)
downloadpolymatrix-3f038881247094b8a13a8b6417007703655896dd.tar.gz
polymatrix-3f038881247094b8a13a8b6417007703655896dd.zip
Fix v_stack, add __str__ representation to some operation mixins
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/__init__.py2
-rw-r--r--polymatrix/expression/expression.py10
-rw-r--r--polymatrix/expression/impl.py29
-rw-r--r--polymatrix/expression/mixins/eyeexprmixin.py1
-rw-r--r--polymatrix/expression/mixins/symmetricexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py5
6 files changed, 45 insertions, 7 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 6f7eae2..b150cc5 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -15,7 +15,7 @@ def v_stack(
def gen_underlying():
for expr in expressions:
if isinstance(expr, Expression):
- yield expr
+ yield expr.underlying
else:
yield polymatrix.expression.from_.from_(expr)
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 3c52297..fbe7ac3 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -493,12 +493,12 @@ class Expression(ExpressionBaseMixin, ABC):
)
-@dataclassabc(frozen=True, repr=False)
+@dataclassabc(frozen=True)
class ExpressionImpl(Expression):
underlying: ExpressionBaseMixin
- def __repr__(self) -> str:
- return self.underlying.__repr__()
+ def __str__(self):
+ return f"Expression({self.underlying})"
def copy(self, underlying: ExpressionBaseMixin) -> Expression:
return dataclasses.replace(
@@ -539,8 +539,8 @@ class VariableExpression(Expression, Variable):
class VariableExpressionImpl(VariableExpression):
underlying: ExpressionBaseMixin
- def __repr__(self) -> str:
- return self.underlying.__repr__()
+ def __str__(self):
+ return f"VariableExpression({self.underlying})"
def copy(self, underlying: ExpressionBaseMixin) -> Expression:
return init_expression(underlying)
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 119d8c7..600ad81 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -83,6 +83,9 @@ class AdditionExprImpl(AdditionExprMixin):
def __repr__(self):
return f"{self.__class__.__name__}(left={self.left}, right={self.right})"
+ def __str__(self):
+ return f"({self.left} + {self.right})"
+
@dataclassabc.dataclassabc(frozen=True)
class BlockDiagExprImpl(BlockDiagExprMixin):
@@ -128,6 +131,9 @@ class ElemMultExprImpl(ElemMultExprMixin):
left: ExpressionBaseMixin
right: ExpressionBaseMixin
+ def __str__(self):
+ return f"({self.left} * {self.right})"
+
@dataclassabc.dataclassabc(frozen=True)
class EvalExprImpl(EvalExprMixin):
@@ -151,6 +157,11 @@ class FilterExprImpl(FilterExprMixin):
class FromNumbersExprImpl(FromNumbersExprMixin):
data: tuple[tuple[int | float]]
+ def __str__(self):
+ if len(self.data) == 1:
+ if len(self.data[0]) == 1:
+ return str(self.data[0][0])
+
@dataclassabc.dataclassabc(frozen=True)
class FromNumpyExprImpl(FromNumpyExprMixin):
@@ -237,9 +248,13 @@ class MatrixMultExprImpl(MatrixMultExprMixin):
stack: tuple[FrameSummary]
# implement custom __repr__ method that returns a representation without the stack
+ # FIXME: remove this?
def __repr__(self):
return f"{self.__class__.__name__}(left={self.left}, right={self.right})"
+ def __str__(self):
+ return f"({self.left} @ {self.right})"
+
@dataclassabc.dataclassabc(frozen=True)
class DegreeExprImpl(DegreeExprMixin):
@@ -367,6 +382,9 @@ class ToSymmetricMatrixExprImpl(ToSymmetricMatrixExprMixin):
class TransposeExprImpl(TransposeExprMixin):
underlying: ExpressionBaseMixin
+ def __str__(self):
+ return f"({self.underlying}).T"
+
@dataclassabc.dataclassabc(frozen=True)
class TruncateExprImpl(TruncateExprMixin):
@@ -381,7 +399,16 @@ class VariableImpl(VariableMixin):
name: str
shape: tuple[int, int]
+ def __str__(self):
+ return self.name
+
@dataclassabc.dataclassabc(frozen=True)
class VStackExprImpl(VStackExprMixin):
- underlying: tuple
+ underlying: tuple[ExpressionBaseMixin, ...]
+
+ def __str__(self):
+ inner = ", ".join(map(str, self.underlying))
+ return f"v_stack({inner})"
+
+
diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py
index e3b28c2..b18c89b 100644
--- a/polymatrix/expression/mixins/eyeexprmixin.py
+++ b/polymatrix/expression/mixins/eyeexprmixin.py
@@ -39,6 +39,7 @@ class EyeExprMixin(ExpressionBaseMixin):
else:
raise Exception(f"{(row, col)=} is out of bounds")
+ # FIXME: this behaviour is counterintuitive, eye should take just a number for the dimension
n_row = variable.shape[0]
polymatrix = EyePolyMatrix(
diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py
index 8bc398f..f585240 100644
--- a/polymatrix/expression/mixins/symmetricexprmixin.py
+++ b/polymatrix/expression/mixins/symmetricexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.typing import PolyDict
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -41,6 +42,10 @@ class SymmetricExprMixin(ExpressionBaseMixin):
def shape(self) -> tuple[int, int]:
return self.underlying.shape
+ def at(self, row: int, col: int) -> PolyDict:
+ # FIXME: this is a quick workaround
+ return PolyDict(self.get_poly(row, col) or {})
+
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
def gen_symmetric_monomials():
for i_row, i_col in ((row, col), (col, row)):
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
index ab23ee6..2091f88 100644
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -8,6 +8,7 @@ if typing.TYPE_CHECKING:
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.polymatrix.typing import PolyDict
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
@@ -45,6 +46,10 @@ class VStackExprMixin(ExpressionBaseMixin):
underlying_row_range: tuple[tuple[int, int], ...]
shape: tuple[int, int]
+ def at(self, row: int, col: int) -> PolyDict:
+ # FIXME: this is a quick workaround
+ return self.get_poly(row, col) or PolyDict.empty()
+
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
for polymatrix, (row_start, row_end) in zip(
self.all_underlying, self.underlying_row_range