summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-27 12:24:10 +0200
committerNao Pross <np@0hm.ch>2024-05-27 12:24:10 +0200
commitd30823453e22d6bb53521997708c5903990492f9 (patch)
treec01ede0aef823f6e5be6084cca50db8d53a892ba
parentFix pretty printing of PolyMatrix objects (diff)
downloadpolymatrix-d30823453e22d6bb53521997708c5903990492f9.tar.gz
polymatrix-d30823453e22d6bb53521997708c5903990492f9.zip
Create NamedExpr and polymatrix.give_name to give shorter names to expressions
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py2
-rw-r--r--polymatrix/expression/__init__.py5
-rw-r--r--polymatrix/expression/impl.py10
-rw-r--r--polymatrix/expression/init.py5
-rw-r--r--polymatrix/expression/mixins/namedexprmixin.py32
5 files changed, 54 insertions, 0 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 475749f..ba13779 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -20,6 +20,7 @@ from polymatrix.expression import (
concatenate as internal_concatenate,
block_diag as internal_block_diag,
lower_triangular as internal_lower_triangular,
+ give_name as internal_give_name,
)
from polymatrix.expression.to import (
@@ -43,6 +44,7 @@ product = internal_product
concatenate = internal_concatenate
block_diag = internal_block_diag
lower_triangular = internal_lower_triangular
+give_name = internal_give_name
to_constant_repr = internal_to_constant
to_sympy_repr = internal_to_sympy
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 30be90b..56bc83a 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -13,6 +13,7 @@ from polymatrix.expression.init import (
init_block_diag_expr,
init_concatenate_expr,
init_lower_triangular_expr,
+ init_named_expr,
init_ns_expr,
)
@@ -103,3 +104,7 @@ def product(
@convert_args_to_expression
def lower_triangular(vector: Expression):
return init_expression(init_lower_triangular_expr(vector.underlying))
+
+
+def give_name(expr: Expression, name: str):
+ return init_expression(init_named_expr(expr, name))
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index edc7b08..bbd932a 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -36,6 +36,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomial
from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
+from polymatrix.expression.mixins.namedexprmixin import NamedExprMixin
from polymatrix.expression.mixins.negationexprmixin import NegationExprMixin
from polymatrix.expression.mixins.nsexprmixin import NsExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
@@ -322,6 +323,15 @@ class MaxExprImpl(MaxExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class NamedExprImpl(NamedExprMixin):
+ underlying: ExpressionBaseMixin
+ name: str
+
+ def __str__(self):
+ return str(self.name)
+
+
+@dataclassabc.dataclassabc(frozen=True)
class NegationExprImpl(NegationExprMixin):
underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 5604444..1db855f 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -298,6 +298,11 @@ def init_max_expr(underlying: ExpressionBaseMixin):
underlying=underlying)
+def init_named_expr(underlying: ExpressionBaseMixin, name: str):
+ return polymatrix.expression.impl.NamedExprImpl(
+ underlying=underlying, name=name)
+
+
def init_negation_expr(underlying: ExpressionBaseMixin):
return polymatrix.expression.impl.NegationExprImpl(
underlying=underlying)
diff --git a/polymatrix/expression/mixins/namedexprmixin.py b/polymatrix/expression/mixins/namedexprmixin.py
new file mode 100644
index 0000000..e6d7284
--- /dev/null
+++ b/polymatrix/expression/mixins/namedexprmixin.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.index import PolyDict, MonomialIndex
+from polymatrix.polymatrix.init import init_broadcast_poly_matrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+
+class NamedExprMixin(ExpressionBaseMixin):
+ """
+ Give a name to an expression.
+
+ This is mostly to make the printing more understandable.
+ See also NamedExprImpl.__str__.
+ """
+
+ @property
+ @abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ """ The expresssion """
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """ The name for the expression """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ return self.underlying.apply(state)