summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/additionexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py44
1 files changed, 35 insertions, 9 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index d274f2a..2c7a652 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -5,6 +5,7 @@ import dataclass_abc
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
@@ -20,24 +21,49 @@ class AdditionExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.left.shape
-
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- assert left.shape == right.shape, f'{left.shape} != {right.shape}'
-
terms = {}
- for underlying in (left, right):
+ if left.shape == (1, 1):
+ left, right = right, left
+
+ if left.shape != (1, 1) and right.shape == (1, 1):
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class BroadCastedPolyMatrix(PolyMatrixMixin):
+ underlying_monomials: tuple[tuple[int], float]
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ return self.underlying_monomials
+
+ try:
+ underlying_terms = right.get_poly(0, 0)
+
+ except KeyError:
+ pass
+
+ else:
+ broadcasted_right = BroadCastedPolyMatrix(
+ underlying_monomials=underlying_terms,
+ shape=left.shape,
+ )
+
+ all_underlying = (left, broadcasted_right)
+
+ else:
+ assert left.shape == right.shape, f'{left.shape} != {right.shape}'
+
+ all_underlying = (left, right)
+
+ for underlying in all_underlying:
for row in range(left.shape[0]):
for col in range(left.shape[1]):