diff options
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 44 |
1 files changed, 35 insertions, 9 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index d274f2a..2c7a652 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -5,6 +5,7 @@ import dataclass_abc from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState @@ -20,24 +21,49 @@ class AdditionExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.left.shape - # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - assert left.shape == right.shape, f'{left.shape} != {right.shape}' - terms = {} - for underlying in (left, right): + if left.shape == (1, 1): + left, right = right, left + + if left.shape != (1, 1) and right.shape == (1, 1): + + @dataclass_abc.dataclass_abc(frozen=True) + class BroadCastedPolyMatrix(PolyMatrixMixin): + underlying_monomials: tuple[tuple[int], float] + shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + return self.underlying_monomials + + try: + underlying_terms = right.get_poly(0, 0) + + except KeyError: + pass + + else: + broadcasted_right = BroadCastedPolyMatrix( + underlying_monomials=underlying_terms, + shape=left.shape, + ) + + all_underlying = (left, broadcasted_right) + + else: + assert left.shape == right.shape, f'{left.shape} != {right.shape}' + + all_underlying = (left, right) + + for underlying in all_underlying: for row in range(left.shape[0]): for col in range(left.shape[1]): |