summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/expression.py2
-rw-r--r--polymatrix/expression/impl.py9
-rw-r--r--polymatrix/expression/init.py14
-rw-r--r--polymatrix/expression/mixins/negationexprmixin.py33
4 files changed, 50 insertions, 8 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 3151d25..63bdf6e 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -77,7 +77,7 @@ class Expression(ExpressionBaseMixin, ABC):
return self._binary(polymatrix.expression.init.init_power_expr, self, exponent)
def __neg__(self):
- return self * (-1)
+ return init_expression(polymatrix.expression.init.init_negation_expr(self.underlying))
def __radd__(self, other):
return self._binary(polymatrix.expression.init.init_addition_expr, other, self)
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index e0b2aa6..edc7b08 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -36,6 +36,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomial
from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
+from polymatrix.expression.mixins.negationexprmixin import NegationExprMixin
from polymatrix.expression.mixins.nsexprmixin import NsExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin
@@ -321,6 +322,14 @@ class MaxExprImpl(MaxExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class NegationExprImpl(NegationExprMixin):
+ underlying: ExpressionBaseMixin
+
+ def __str__(self):
+ return f"(-{self.underlying})"
+
+
+@dataclassabc.dataclassabc(frozen=True)
class NsExprImpl(NsExprMixin):
n: int | float | ExpressionBaseMixin
shape: tuple[int, int] | ExpressionBaseMixin
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 849f017..5604444 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -165,8 +165,6 @@ def init_from_statemonad(monad: StateMonad):
# TODO: remove this function, replaced by from_any
-# NP: this function should be split up into smaller functions, one for each "from" type
-# NP: and each "from" should be documented, explaining how it is interpreted.
def init_from_expr_or_none(
data: FromSupportedTypes,
) -> ExpressionBaseMixin | None:
@@ -295,12 +293,14 @@ def init_matrix_mult_expr(
)
-def init_max_expr(
- underlying: ExpressionBaseMixin,
-):
+def init_max_expr(underlying: ExpressionBaseMixin):
return polymatrix.expression.impl.MaxExprImpl(
- underlying=underlying,
- )
+ underlying=underlying)
+
+
+def init_negation_expr(underlying: ExpressionBaseMixin):
+ return polymatrix.expression.impl.NegationExprImpl(
+ underlying=underlying)
def init_ns_expr(n: int | float | ExpressionBaseMixin, shape: tuple[int, int] | ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/negationexprmixin.py b/polymatrix/expression/mixins/negationexprmixin.py
new file mode 100644
index 0000000..eaaaff3
--- /dev/null
+++ b/polymatrix/expression/mixins/negationexprmixin.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import math
+
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+
+class NegationExprMixin(ExpressionBaseMixin):
+ """ Negate expression, or, multiply by -1. """
+
+ @property
+ @abstractmethod
+ def underlying(self):
+ """ The expression that will be negated. """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ state, u = self.underlying.apply(state)
+ return state, init_poly_matrix(PolyMatrixDict({
+ entry: PolyDict({
+ monomial: -coeff
+ })
+ for entry, poly in u.entries()
+ for monomial, coeff in poly.terms()
+ }), u.shape)
+