diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-06-13 16:06:25 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-06-13 16:06:25 +0200 |
commit | 68eb5f9e15be57a32317810237061b16c96a3271 (patch) | |
tree | e02cd2160964ecc7f92ee2a35e8342bbe1c8de6f /polymatrix/expression/mixins/additionexprmixin.py | |
parent | introduce state monad and functions to go along with it (diff) | |
download | polymatrix-68eb5f9e15be57a32317810237061b16c96a3271.tar.gz polymatrix-68eb5f9e15be57a32317810237061b16c96a3271.zip |
add eye, sum and symmetric operation
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 44 |
1 files changed, 35 insertions, 9 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index d274f2a..2c7a652 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -5,6 +5,7 @@ import dataclass_abc from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState @@ -20,24 +21,49 @@ class AdditionExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.left.shape - # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - assert left.shape == right.shape, f'{left.shape} != {right.shape}' - terms = {} - for underlying in (left, right): + if left.shape == (1, 1): + left, right = right, left + + if left.shape != (1, 1) and right.shape == (1, 1): + + @dataclass_abc.dataclass_abc(frozen=True) + class BroadCastedPolyMatrix(PolyMatrixMixin): + underlying_monomials: tuple[tuple[int], float] + shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + return self.underlying_monomials + + try: + underlying_terms = right.get_poly(0, 0) + + except KeyError: + pass + + else: + broadcasted_right = BroadCastedPolyMatrix( + underlying_monomials=underlying_terms, + shape=left.shape, + ) + + all_underlying = (left, broadcasted_right) + + else: + assert left.shape == right.shape, f'{left.shape} != {right.shape}' + + all_underlying = (left, right) + + for underlying in all_underlying: for row in range(left.shape[0]): for col in range(left.shape[1]): |