diff options
-rw-r--r-- | polymatrix/expression/expression.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 9 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 14 | ||||
-rw-r--r-- | polymatrix/expression/mixins/negationexprmixin.py | 33 |
4 files changed, 50 insertions, 8 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 3151d25..63bdf6e 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -77,7 +77,7 @@ class Expression(ExpressionBaseMixin, ABC): return self._binary(polymatrix.expression.init.init_power_expr, self, exponent) def __neg__(self): - return self * (-1) + return init_expression(polymatrix.expression.init.init_negation_expr(self.underlying)) def __radd__(self, other): return self._binary(polymatrix.expression.init.init_addition_expr, other, self) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index e0b2aa6..edc7b08 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -36,6 +36,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomial from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin +from polymatrix.expression.mixins.negationexprmixin import NegationExprMixin from polymatrix.expression.mixins.nsexprmixin import NsExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin @@ -321,6 +322,14 @@ class MaxExprImpl(MaxExprMixin): @dataclassabc.dataclassabc(frozen=True) +class NegationExprImpl(NegationExprMixin): + underlying: ExpressionBaseMixin + + def __str__(self): + return f"(-{self.underlying})" + + +@dataclassabc.dataclassabc(frozen=True) class NsExprImpl(NsExprMixin): n: int | float | ExpressionBaseMixin shape: tuple[int, int] | ExpressionBaseMixin diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 849f017..5604444 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -165,8 +165,6 @@ def init_from_statemonad(monad: StateMonad): # TODO: remove this function, replaced by from_any -# NP: this function should be split up into smaller functions, one for each "from" type -# NP: and each "from" should be documented, explaining how it is interpreted. def init_from_expr_or_none( data: FromSupportedTypes, ) -> ExpressionBaseMixin | None: @@ -295,12 +293,14 @@ def init_matrix_mult_expr( ) -def init_max_expr( - underlying: ExpressionBaseMixin, -): +def init_max_expr(underlying: ExpressionBaseMixin): return polymatrix.expression.impl.MaxExprImpl( - underlying=underlying, - ) + underlying=underlying) + + +def init_negation_expr(underlying: ExpressionBaseMixin): + return polymatrix.expression.impl.NegationExprImpl( + underlying=underlying) def init_ns_expr(n: int | float | ExpressionBaseMixin, shape: tuple[int, int] | ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/negationexprmixin.py b/polymatrix/expression/mixins/negationexprmixin.py new file mode 100644 index 0000000..eaaaff3 --- /dev/null +++ b/polymatrix/expression/mixins/negationexprmixin.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import math + +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict +from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + + +class NegationExprMixin(ExpressionBaseMixin): + """ Negate expression, or, multiply by -1. """ + + @property + @abstractmethod + def underlying(self): + """ The expression that will be negated. """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + state, u = self.underlying.apply(state) + return state, init_poly_matrix(PolyMatrixDict({ + entry: PolyDict({ + monomial: -coeff + }) + for entry, poly in u.entries() + for monomial, coeff in poly.terms() + }), u.shape) + |