diff options
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 79 |
1 files changed, 51 insertions, 28 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 2cbbe1e..7dbf1d2 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,12 +1,10 @@ import abc import math -import typing -import dataclassabc +from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.tooperatorexception import to_operator_exception from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -28,7 +26,32 @@ class AdditionExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + @staticmethod + def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]): + # broadcast left + if left.shape == (1, 1) and right.shape != (1, 1): + left = BroadcastPolyMatrixImpl( + polynomial=left.get_poly(0, 0), + shape=right.shape, + ) + + # broadcast right + elif left.shape != (1, 1) and right.shape == (1, 1): + right = BroadcastPolyMatrixImpl( + polynomial=right.get_poly(0, 0), + shape=left.shape, + ) + + else: + if not (left.shape == right.shape): + raise AssertionError(to_operator_exception( + message=f'{left.shape} != {right.shape}', + stack=stack, + )) + + return left, right + + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, @@ -36,38 +59,37 @@ class AdditionExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - if left.shape == (1, 1): - left, right = right, left + # if left.shape == (1, 1): + # left, right = right, left - if left.shape != (1, 1) and right.shape == (1, 1): + # if left.shape != (1, 1) and right.shape == (1, 1): - @dataclassabc.dataclassabc(frozen=True) - class BroadCastedPolyMatrix(PolyMatrixMixin): - underlying_monomials: tuple[tuple[int], float] - shape: tuple[int, int] + # # @dataclassabc.dataclassabc(frozen=True) + # # class BroadCastedPolyMatrix(PolyMatrixMixin): + # # underlying_monomials: tuple[tuple[int], float] + # # shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: - return self.underlying_monomials + # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + # # return self.underlying_monomials - polynomial = right.get_poly(0, 0) - if polynomial is not None: + # right = BroadcastPolyMatrixImpl( + # polynomial=right.get_poly(0, 0), + # shape=left.shape, + # ) - broadcasted_right = BroadCastedPolyMatrix( - underlying_monomials=polynomial, - shape=left.shape, - ) + # # all_underlying = (left, broadcasted_right) - all_underlying = (left, broadcasted_right) + # else: + # if not (left.shape == right.shape): + # raise AssertionError(to_operator_exception( + # message=f'{left.shape} != {right.shape}', + # stack=self.stack, + # )) - else: - if not (left.shape == right.shape): - raise AssertionError(to_operator_exception( - message=f'{left.shape} != {right.shape}', - stack=self.stack, - )) + # # all_underlying = (left, right) - all_underlying = (left, right) + left, right = self.broadcast(left, right, self.stack) terms = {} @@ -76,7 +98,7 @@ class AdditionExprMixin(ExpressionBaseMixin): terms_row_col = {} - for underlying in all_underlying: + for underlying in (left, right): polynomial = underlying.get_poly(row, col) if polynomial is None: @@ -90,6 +112,7 @@ class AdditionExprMixin(ExpressionBaseMixin): if monomial not in terms_row_col: terms_row_col[monomial] = value + else: terms_row_col[monomial] += value |