diff options
-rw-r--r-- | polymatrix/expression/expression.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 15 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 11 | ||||
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 24 | ||||
-rw-r--r-- | polymatrix/expression/mixins/subtractionexprmixin.py | 57 |
5 files changed, 100 insertions, 11 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 58a3883..3151d25 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -94,7 +94,9 @@ class Expression(ExpressionBaseMixin, ABC): return other + (-self) def __sub__(self, other): - return self + (-other) + return self._binary( + polymatrix.expression.init.init_subtraction_expr, other, self + ) def __truediv__(self, other): if not isinstance(other, float | int): diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index aedc7da..e0b2aa6 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -48,6 +48,7 @@ from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprM from polymatrix.expression.mixins.shapeexprmixin import ShapeExprMixin from polymatrix.expression.mixins.sliceexprmixin import SliceExprMixin from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin +from polymatrix.expression.mixins.subtractionexprmixin import SubtractionExprMixin from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin from polymatrix.expression.mixins.sumexprmixin import SumExprMixin from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin @@ -420,6 +421,20 @@ class SqueezeExprImpl(SqueezeExprMixin): @dataclassabc.dataclassabc(frozen=True) +class SubtractionExprImpl(SubtractionExprMixin): + left: ExpressionBaseMixin + right: ExpressionBaseMixin + stack: tuple[FrameSummary] + + # implement custom __repr__ method that returns a representation without the stack + def __repr__(self): + return f"{self.__class__.__name__}(left={repr(self.left)}, right={repr(self.right)})" + + def __str__(self): + return f"({self.left} - {self.right})" + + +@dataclassabc.dataclassabc(frozen=True) class SubtractMonomialsExprImpl(SubtractMonomialsExprMixin): underlying: ExpressionBaseMixin monomials: ExpressionBaseMixin diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 2740d38..849f017 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -407,6 +407,17 @@ def init_squeeze_expr( underlying=underlying, ) +def init_subtraction_expr( + left: ExpressionBaseMixin, + right: ExpressionBaseMixin, + stack: tuple[FrameSummary], +): + return polymatrix.expression.impl.SubtractionExprImpl( + left=left, + right=right, + stack=stack, + ) + def init_subtract_monomials_expr( underlying: ExpressionBaseMixin, diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 4b0a9c9..836de00 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,10 +1,14 @@ -import abc +from __future__ import annotations + import math +from abc import abstractmethod +from typing_extensions import override + from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict -from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix @@ -22,22 +26,19 @@ class AdditionExprMixin(ExpressionBaseMixin): """ @property - @abc.abstractmethod + @abstractmethod def left(self) -> ExpressionBaseMixin: ... @property - @abc.abstractmethod + @abstractmethod def right(self) -> ExpressionBaseMixin: ... @property - @abc.abstractmethod + @abstractmethod def stack(self) -> tuple[FrameSummary]: ... - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) @@ -50,6 +51,9 @@ class AdditionExprMixin(ExpressionBaseMixin): for monomial, coeff in right_poly.terms(): if monomial in new_poly: new_poly[monomial] += coeff + + if math.isclose(new_poly[monomial], 0): + del new_poly[monomial] else: new_poly[monomial] = coeff diff --git a/polymatrix/expression/mixins/subtractionexprmixin.py b/polymatrix/expression/mixins/subtractionexprmixin.py new file mode 100644 index 0000000..41b2509 --- /dev/null +++ b/polymatrix/expression/mixins/subtractionexprmixin.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import math + +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.utils.getstacklines import FrameSummary +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.index import PolyMatrixDict +from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix + + +class SubtractionExprMixin(ExpressionBaseMixin): + """ + Subtract two expressions. If one of the two expression is scalar, it is + broadcast to match the shape of the other. + """ + + @property + @abstractmethod + def left(self) -> ExpressionBaseMixin: ... + + @property + @abstractmethod + def right(self) -> ExpressionBaseMixin: ... + + @property + @abstractmethod + def stack(self) -> tuple[FrameSummary]: ... + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + state, left = self.left.apply(state=state) + state, right = self.right.apply(state=state) + + left, right = broadcast_poly_matrix(left, right, self.stack) + + # Keep left as-is and subtract the stuff from right + result = PolyMatrixDict.empty() + for entry, right_poly in right.entries(): + new_poly = left.at(*entry) + for monomial, coeff in right_poly.terms(): + if monomial in new_poly: + new_poly[monomial] -= coeff + + if math.isclose(new_poly[monomial], 0): + del new_poly[monomial] + else: + new_poly[monomial] = coeff + + result[*entry] = new_poly + return state, init_poly_matrix(result, left.shape) + |