diff options
-rw-r--r-- | polymatrix/expression/mixins/expressionmixin.py | 85 |
1 files changed, 41 insertions, 44 deletions
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index 9ea3529..cf4a674 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -73,12 +73,6 @@ class ExpressionMixin( ), ) - def __radd__(self, other): - return self + other - - def __sub__(self, other): - return self + other * (-1) - def __getattr__(self, name): attr = getattr(self.underlying, name) @@ -91,9 +85,16 @@ class ExpressionMixin( else: return attr - def __mul__(self, other) -> 'ExpressionMixin': - # assert isinstance(other, float) + def __getitem__(self, key: tuple[int, int]): + return dataclasses.replace( + self, + underlying=init_get_item_expr( + underlying=self.underlying, + index=key, + ), + ) + def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin': match other: case ExpressionBaseMixin(): right = other.underlying @@ -102,19 +103,15 @@ class ExpressionMixin( return dataclasses.replace( self, - underlying=init_elem_mult_expr( + underlying=init_matrix_mult_expr( left=self.underlying, right=right, ), ) - def __rmul__(self, other): - return self * other - - def __neg__(self): - return self * (-1) + def __mul__(self, other) -> 'ExpressionMixin': + # assert isinstance(other, float) - def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin': match other: case ExpressionBaseMixin(): right = other.underlying @@ -123,17 +120,29 @@ class ExpressionMixin( return dataclasses.replace( self, - underlying=init_matrix_mult_expr( + underlying=init_elem_mult_expr( left=self.underlying, right=right, ), ) + def __neg__(self): + return self * (-1) + + def __radd__(self, other): + return self + other + def __rmatmul__(self, other): other = init_from_sympy_expr(other) return other @ self + def __rmul__(self, other): + return self * other + + def __sub__(self, other): + return self + other * (-1) + def __truediv__(self, other: ExpressionBaseMixin): match other: case ExpressionBaseMixin(): @@ -149,24 +158,6 @@ class ExpressionMixin( ), ) - def __getitem__(self, key: tuple[int, int]): - return dataclasses.replace( - self, - underlying=init_get_item_expr( - underlying=self.underlying, - index=key, - ), - ) - - @property - def T(self) -> 'ExpressionMixin': - return dataclasses.replace( - self, - underlying=init_transpose_expr( - underlying=self.underlying, - ), - ) - def cache(self) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -291,15 +282,6 @@ class ExpressionMixin( ), ) - def expressions_in(self, variables: tuple): - return dataclasses.replace( - self, - underlying=init_linear_in_monomials_in( - underlying=self.underlying, - variables=variables, - ), - ) - def max(self) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -438,6 +420,15 @@ class ExpressionMixin( ), ) + @property + def T(self) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_transpose_expr( + underlying=self.underlying, + ), + ) + def to_constant(self) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -454,12 +445,18 @@ class ExpressionMixin( ), ) - def truncate(self, variables: tuple, degrees: tuple[int]): + def truncate( + self, + variables: tuple, + degrees: tuple[int], + inverse: bool = None, + ): return dataclasses.replace( self, underlying=init_truncate_expr( underlying=self.underlying, variables=variables, degrees=degrees, + inverse=inverse, ), ) |