diff options
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 84 |
1 files changed, 14 insertions, 70 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 7dbf1d2..e3580f0 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,9 +1,8 @@ import abc import math -from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl +from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix from polymatrix.utils.getstacklines import FrameSummary -from polymatrix.utils.tooperatorexception import to_operator_exception from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState @@ -26,31 +25,6 @@ class AdditionExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - @staticmethod - def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]): - # broadcast left - if left.shape == (1, 1) and right.shape != (1, 1): - left = BroadcastPolyMatrixImpl( - polynomial=left.get_poly(0, 0), - shape=right.shape, - ) - - # broadcast right - elif left.shape != (1, 1) and right.shape == (1, 1): - right = BroadcastPolyMatrixImpl( - polynomial=right.get_poly(0, 0), - shape=left.shape, - ) - - else: - if not (left.shape == right.shape): - raise AssertionError(to_operator_exception( - message=f'{left.shape} != {right.shape}', - stack=stack, - )) - - return left, right - # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, @@ -59,44 +33,14 @@ class AdditionExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - # if left.shape == (1, 1): - # left, right = right, left - - # if left.shape != (1, 1) and right.shape == (1, 1): - - # # @dataclassabc.dataclassabc(frozen=True) - # # class BroadCastedPolyMatrix(PolyMatrixMixin): - # # underlying_monomials: tuple[tuple[int], float] - # # shape: tuple[int, int] - - # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: - # # return self.underlying_monomials - - - # right = BroadcastPolyMatrixImpl( - # polynomial=right.get_poly(0, 0), - # shape=left.shape, - # ) - - # # all_underlying = (left, broadcasted_right) - - # else: - # if not (left.shape == right.shape): - # raise AssertionError(to_operator_exception( - # message=f'{left.shape} != {right.shape}', - # stack=self.stack, - # )) - - # # all_underlying = (left, right) - - left, right = self.broadcast(left, right, self.stack) + left, right = broadcast_poly_matrix(left, right, self.stack) - terms = {} + poly_matrix_data = {} for row in range(left.shape[0]): for col in range(left.shape[1]): - terms_row_col = {} + poly_data = {} for underlying in (left, right): @@ -104,26 +48,26 @@ class AdditionExprMixin(ExpressionBaseMixin): if polynomial is None: continue - if len(terms_row_col) == 0: - terms_row_col = dict(polynomial) + if len(poly_data) == 0: + poly_data = dict(polynomial) else: for monomial, value in polynomial.items(): - if monomial not in terms_row_col: - terms_row_col[monomial] = value + if monomial not in poly_data: + poly_data[monomial] = value else: - terms_row_col[monomial] += value + poly_data[monomial] += value - if math.isclose(terms_row_col[monomial], 0): - del terms_row_col[monomial] + if math.isclose(poly_data[monomial], 0): + del poly_data[monomial] - if 0 < len(terms_row_col): - terms[row, col] = terms_row_col + if 0 < len(poly_data): + poly_matrix_data[row, col] = poly_data poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=left.shape, ) |