diff options
Diffstat (limited to 'polymatrix/expression/mixins/additionexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/additionexprmixin.py | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index e247ebd..6257d55 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -3,11 +3,13 @@ import math import typing import dataclassabc +from polymatrix.utils.getstacklines import FrameSummary +from polymatrix.utils.tooperatorexception import to_operator_exception from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin class AdditionExprMixin(ExpressionBaseMixin): @@ -21,6 +23,11 @@ class AdditionExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... + @property + @abc.abstractmethod + def stack(self) -> tuple[FrameSummary]: + ... + # overwrites abstract method of `ExpressionBaseMixin` def apply( self, @@ -54,7 +61,11 @@ class AdditionExprMixin(ExpressionBaseMixin): all_underlying = (left, broadcasted_right) else: - assert left.shape == right.shape, f'{left.shape} != {right.shape}' + if not (left.shape == right.shape): + raise AssertionError(to_operator_exception( + message=f'{left.shape} != {right.shape}', + stack=self.stack, + )) all_underlying = (left, right) |