summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-27 10:44:25 +0200
committerNao Pross <np@0hm.ch>2024-05-27 10:44:25 +0200
commitddc3fa1ce8d482cd027e73a8244dcfe3338bc648 (patch)
tree5a01a89041d5a3d2be9347a9782ccf33fb83eebb
parentFix bug in .diff() and add DerivativeExpr.__str__ (diff)
downloadpolymatrix-ddc3fa1ce8d482cd027e73a8244dcfe3338bc648.tar.gz
polymatrix-ddc3fa1ce8d482cd027e73a8244dcfe3338bc648.zip
Clean up AdditionExpr, create SubtractinExpr for pretty printing
-rw-r--r--polymatrix/expression/expression.py4
-rw-r--r--polymatrix/expression/impl.py15
-rw-r--r--polymatrix/expression/init.py11
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py24
-rw-r--r--polymatrix/expression/mixins/subtractionexprmixin.py57
5 files changed, 100 insertions, 11 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 58a3883..3151d25 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -94,7 +94,9 @@ class Expression(ExpressionBaseMixin, ABC):
return other + (-self)
def __sub__(self, other):
- return self + (-other)
+ return self._binary(
+ polymatrix.expression.init.init_subtraction_expr, other, self
+ )
def __truediv__(self, other):
if not isinstance(other, float | int):
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index aedc7da..e0b2aa6 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -48,6 +48,7 @@ from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprM
from polymatrix.expression.mixins.shapeexprmixin import ShapeExprMixin
from polymatrix.expression.mixins.sliceexprmixin import SliceExprMixin
from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin
+from polymatrix.expression.mixins.subtractionexprmixin import SubtractionExprMixin
from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin
from polymatrix.expression.mixins.sumexprmixin import SumExprMixin
from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin
@@ -420,6 +421,20 @@ class SqueezeExprImpl(SqueezeExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class SubtractionExprImpl(SubtractionExprMixin):
+ left: ExpressionBaseMixin
+ right: ExpressionBaseMixin
+ stack: tuple[FrameSummary]
+
+ # implement custom __repr__ method that returns a representation without the stack
+ def __repr__(self):
+ return f"{self.__class__.__name__}(left={repr(self.left)}, right={repr(self.right)})"
+
+ def __str__(self):
+ return f"({self.left} - {self.right})"
+
+
+@dataclassabc.dataclassabc(frozen=True)
class SubtractMonomialsExprImpl(SubtractMonomialsExprMixin):
underlying: ExpressionBaseMixin
monomials: ExpressionBaseMixin
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 2740d38..849f017 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -407,6 +407,17 @@ def init_squeeze_expr(
underlying=underlying,
)
+def init_subtraction_expr(
+ left: ExpressionBaseMixin,
+ right: ExpressionBaseMixin,
+ stack: tuple[FrameSummary],
+):
+ return polymatrix.expression.impl.SubtractionExprImpl(
+ left=left,
+ right=right,
+ stack=stack,
+ )
+
def init_subtract_monomials_expr(
underlying: ExpressionBaseMixin,
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 4b0a9c9..836de00 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,10 +1,14 @@
-import abc
+from __future__ import annotations
+
import math
+from abc import abstractmethod
+from typing_extensions import override
+
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.index import PolyMatrixDict
-from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expressionstate import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
@@ -22,22 +26,19 @@ class AdditionExprMixin(ExpressionBaseMixin):
"""
@property
- @abc.abstractmethod
+ @abstractmethod
def left(self) -> ExpressionBaseMixin: ...
@property
- @abc.abstractmethod
+ @abstractmethod
def right(self) -> ExpressionBaseMixin: ...
@property
- @abc.abstractmethod
+ @abstractmethod
def stack(self) -> tuple[FrameSummary]: ...
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]:
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
@@ -50,6 +51,9 @@ class AdditionExprMixin(ExpressionBaseMixin):
for monomial, coeff in right_poly.terms():
if monomial in new_poly:
new_poly[monomial] += coeff
+
+ if math.isclose(new_poly[monomial], 0):
+ del new_poly[monomial]
else:
new_poly[monomial] = coeff
diff --git a/polymatrix/expression/mixins/subtractionexprmixin.py b/polymatrix/expression/mixins/subtractionexprmixin.py
new file mode 100644
index 0000000..41b2509
--- /dev/null
+++ b/polymatrix/expression/mixins/subtractionexprmixin.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import math
+
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.utils.getstacklines import FrameSummary
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.index import PolyMatrixDict
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
+
+
+class SubtractionExprMixin(ExpressionBaseMixin):
+ """
+ Subtract two expressions. If one of the two expression is scalar, it is
+ broadcast to match the shape of the other.
+ """
+
+ @property
+ @abstractmethod
+ def left(self) -> ExpressionBaseMixin: ...
+
+ @property
+ @abstractmethod
+ def right(self) -> ExpressionBaseMixin: ...
+
+ @property
+ @abstractmethod
+ def stack(self) -> tuple[FrameSummary]: ...
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ left, right = broadcast_poly_matrix(left, right, self.stack)
+
+ # Keep left as-is and subtract the stuff from right
+ result = PolyMatrixDict.empty()
+ for entry, right_poly in right.entries():
+ new_poly = left.at(*entry)
+ for monomial, coeff in right_poly.terms():
+ if monomial in new_poly:
+ new_poly[monomial] -= coeff
+
+ if math.isclose(new_poly[monomial], 0):
+ del new_poly[monomial]
+ else:
+ new_poly[monomial] = coeff
+
+ result[*entry] = new_poly
+ return state, init_poly_matrix(result, left.shape)
+