summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/additionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py84
1 files changed, 14 insertions, 70 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 7dbf1d2..e3580f0 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,9 +1,8 @@
import abc
import math
-from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl
+from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
from polymatrix.utils.getstacklines import FrameSummary
-from polymatrix.utils.tooperatorexception import to_operator_exception
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
@@ -26,31 +25,6 @@ class AdditionExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- @staticmethod
- def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]):
- # broadcast left
- if left.shape == (1, 1) and right.shape != (1, 1):
- left = BroadcastPolyMatrixImpl(
- polynomial=left.get_poly(0, 0),
- shape=right.shape,
- )
-
- # broadcast right
- elif left.shape != (1, 1) and right.shape == (1, 1):
- right = BroadcastPolyMatrixImpl(
- polynomial=right.get_poly(0, 0),
- shape=left.shape,
- )
-
- else:
- if not (left.shape == right.shape):
- raise AssertionError(to_operator_exception(
- message=f'{left.shape} != {right.shape}',
- stack=stack,
- ))
-
- return left, right
-
# overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
@@ -59,44 +33,14 @@ class AdditionExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- # if left.shape == (1, 1):
- # left, right = right, left
-
- # if left.shape != (1, 1) and right.shape == (1, 1):
-
- # # @dataclassabc.dataclassabc(frozen=True)
- # # class BroadCastedPolyMatrix(PolyMatrixMixin):
- # # underlying_monomials: tuple[tuple[int], float]
- # # shape: tuple[int, int]
-
- # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
- # # return self.underlying_monomials
-
-
- # right = BroadcastPolyMatrixImpl(
- # polynomial=right.get_poly(0, 0),
- # shape=left.shape,
- # )
-
- # # all_underlying = (left, broadcasted_right)
-
- # else:
- # if not (left.shape == right.shape):
- # raise AssertionError(to_operator_exception(
- # message=f'{left.shape} != {right.shape}',
- # stack=self.stack,
- # ))
-
- # # all_underlying = (left, right)
-
- left, right = self.broadcast(left, right, self.stack)
+ left, right = broadcast_poly_matrix(left, right, self.stack)
- terms = {}
+ poly_matrix_data = {}
for row in range(left.shape[0]):
for col in range(left.shape[1]):
- terms_row_col = {}
+ poly_data = {}
for underlying in (left, right):
@@ -104,26 +48,26 @@ class AdditionExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- if len(terms_row_col) == 0:
- terms_row_col = dict(polynomial)
+ if len(poly_data) == 0:
+ poly_data = dict(polynomial)
else:
for monomial, value in polynomial.items():
- if monomial not in terms_row_col:
- terms_row_col[monomial] = value
+ if monomial not in poly_data:
+ poly_data[monomial] = value
else:
- terms_row_col[monomial] += value
+ poly_data[monomial] += value
- if math.isclose(terms_row_col[monomial], 0):
- del terms_row_col[monomial]
+ if math.isclose(poly_data[monomial], 0):
+ del poly_data[monomial]
- if 0 < len(terms_row_col):
- terms[row, col] = terms_row_col
+ if 0 < len(poly_data):
+ poly_matrix_data[row, col] = poly_data
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=left.shape,
)