From d30823453e22d6bb53521997708c5903990492f9 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 27 May 2024 12:24:10 +0200 Subject: Create NamedExpr and polymatrix.give_name to give shorter names to expressions --- polymatrix/__init__.py | 2 ++ polymatrix/expression/__init__.py | 5 ++++ polymatrix/expression/impl.py | 10 ++++++++ polymatrix/expression/init.py | 5 ++++ polymatrix/expression/mixins/namedexprmixin.py | 32 ++++++++++++++++++++++++++ 5 files changed, 54 insertions(+) create mode 100644 polymatrix/expression/mixins/namedexprmixin.py diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 475749f..ba13779 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -20,6 +20,7 @@ from polymatrix.expression import ( concatenate as internal_concatenate, block_diag as internal_block_diag, lower_triangular as internal_lower_triangular, + give_name as internal_give_name, ) from polymatrix.expression.to import ( @@ -43,6 +44,7 @@ product = internal_product concatenate = internal_concatenate block_diag = internal_block_diag lower_triangular = internal_lower_triangular +give_name = internal_give_name to_constant_repr = internal_to_constant to_sympy_repr = internal_to_sympy diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 30be90b..56bc83a 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -13,6 +13,7 @@ from polymatrix.expression.init import ( init_block_diag_expr, init_concatenate_expr, init_lower_triangular_expr, + init_named_expr, init_ns_expr, ) @@ -103,3 +104,7 @@ def product( @convert_args_to_expression def lower_triangular(vector: Expression): return init_expression(init_lower_triangular_expr(vector.underlying)) + + +def give_name(expr: Expression, name: str): + return init_expression(init_named_expr(expr, name)) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index edc7b08..bbd932a 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -36,6 +36,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomial from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin +from polymatrix.expression.mixins.namedexprmixin import NamedExprMixin from polymatrix.expression.mixins.negationexprmixin import NegationExprMixin from polymatrix.expression.mixins.nsexprmixin import NsExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin @@ -321,6 +322,15 @@ class MaxExprImpl(MaxExprMixin): underlying: ExpressionBaseMixin +@dataclassabc.dataclassabc(frozen=True) +class NamedExprImpl(NamedExprMixin): + underlying: ExpressionBaseMixin + name: str + + def __str__(self): + return str(self.name) + + @dataclassabc.dataclassabc(frozen=True) class NegationExprImpl(NegationExprMixin): underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 5604444..1db855f 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -298,6 +298,11 @@ def init_max_expr(underlying: ExpressionBaseMixin): underlying=underlying) +def init_named_expr(underlying: ExpressionBaseMixin, name: str): + return polymatrix.expression.impl.NamedExprImpl( + underlying=underlying, name=name) + + def init_negation_expr(underlying: ExpressionBaseMixin): return polymatrix.expression.impl.NegationExprImpl( underlying=underlying) diff --git a/polymatrix/expression/mixins/namedexprmixin.py b/polymatrix/expression/mixins/namedexprmixin.py new file mode 100644 index 0000000..e6d7284 --- /dev/null +++ b/polymatrix/expression/mixins/namedexprmixin.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.index import PolyDict, MonomialIndex +from polymatrix.polymatrix.init import init_broadcast_poly_matrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin + +class NamedExprMixin(ExpressionBaseMixin): + """ + Give a name to an expression. + + This is mostly to make the printing more understandable. + See also NamedExprImpl.__str__. + """ + + @property + @abstractmethod + def underlying(self) -> ExpressionBaseMixin: + """ The expresssion """ + + @property + @abstractmethod + def name(self) -> str: + """ The name for the expression """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + return self.underlying.apply(state) -- cgit v1.2.1