summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/additionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 4b0a9c9..836de00 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,10 +1,14 @@
-import abc
+from __future__ import annotations
+
import math
+from abc import abstractmethod
+from typing_extensions import override
+
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.index import PolyMatrixDict
-from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expressionstate import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
@@ -22,22 +26,19 @@ class AdditionExprMixin(ExpressionBaseMixin):
"""
@property
- @abc.abstractmethod
+ @abstractmethod
def left(self) -> ExpressionBaseMixin: ...
@property
- @abc.abstractmethod
+ @abstractmethod
def right(self) -> ExpressionBaseMixin: ...
@property
- @abc.abstractmethod
+ @abstractmethod
def stack(self) -> tuple[FrameSummary]: ...
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]:
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
@@ -50,6 +51,9 @@ class AdditionExprMixin(ExpressionBaseMixin):
for monomial, coeff in right_poly.terms():
if monomial in new_poly:
new_poly[monomial] += coeff
+
+ if math.isclose(new_poly[monomial], 0):
+ del new_poly[monomial]
else:
new_poly[monomial] = coeff