summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/additionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py79
1 files changed, 51 insertions, 28 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 2cbbe1e..7dbf1d2 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,12 +1,10 @@
import abc
import math
-import typing
-import dataclassabc
+from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception
from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -28,7 +26,32 @@ class AdditionExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ @staticmethod
+ def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]):
+ # broadcast left
+ if left.shape == (1, 1) and right.shape != (1, 1):
+ left = BroadcastPolyMatrixImpl(
+ polynomial=left.get_poly(0, 0),
+ shape=right.shape,
+ )
+
+ # broadcast right
+ elif left.shape != (1, 1) and right.shape == (1, 1):
+ right = BroadcastPolyMatrixImpl(
+ polynomial=right.get_poly(0, 0),
+ shape=left.shape,
+ )
+
+ else:
+ if not (left.shape == right.shape):
+ raise AssertionError(to_operator_exception(
+ message=f'{left.shape} != {right.shape}',
+ stack=stack,
+ ))
+
+ return left, right
+
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -36,38 +59,37 @@ class AdditionExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- if left.shape == (1, 1):
- left, right = right, left
+ # if left.shape == (1, 1):
+ # left, right = right, left
- if left.shape != (1, 1) and right.shape == (1, 1):
+ # if left.shape != (1, 1) and right.shape == (1, 1):
- @dataclassabc.dataclassabc(frozen=True)
- class BroadCastedPolyMatrix(PolyMatrixMixin):
- underlying_monomials: tuple[tuple[int], float]
- shape: tuple[int, int]
+ # # @dataclassabc.dataclassabc(frozen=True)
+ # # class BroadCastedPolyMatrix(PolyMatrixMixin):
+ # # underlying_monomials: tuple[tuple[int], float]
+ # # shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
- return self.underlying_monomials
+ # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ # # return self.underlying_monomials
- polynomial = right.get_poly(0, 0)
- if polynomial is not None:
+ # right = BroadcastPolyMatrixImpl(
+ # polynomial=right.get_poly(0, 0),
+ # shape=left.shape,
+ # )
- broadcasted_right = BroadCastedPolyMatrix(
- underlying_monomials=polynomial,
- shape=left.shape,
- )
+ # # all_underlying = (left, broadcasted_right)
- all_underlying = (left, broadcasted_right)
+ # else:
+ # if not (left.shape == right.shape):
+ # raise AssertionError(to_operator_exception(
+ # message=f'{left.shape} != {right.shape}',
+ # stack=self.stack,
+ # ))
- else:
- if not (left.shape == right.shape):
- raise AssertionError(to_operator_exception(
- message=f'{left.shape} != {right.shape}',
- stack=self.stack,
- ))
+ # # all_underlying = (left, right)
- all_underlying = (left, right)
+ left, right = self.broadcast(left, right, self.stack)
terms = {}
@@ -76,7 +98,7 @@ class AdditionExprMixin(ExpressionBaseMixin):
terms_row_col = {}
- for underlying in all_underlying:
+ for underlying in (left, right):
polynomial = underlying.get_poly(row, col)
if polynomial is None:
@@ -90,6 +112,7 @@ class AdditionExprMixin(ExpressionBaseMixin):
if monomial not in terms_row_col:
terms_row_col[monomial] = value
+
else:
terms_row_col[monomial] += value