summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py85
1 files changed, 41 insertions, 44 deletions
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index 9ea3529..cf4a674 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -73,12 +73,6 @@ class ExpressionMixin(
),
)
- def __radd__(self, other):
- return self + other
-
- def __sub__(self, other):
- return self + other * (-1)
-
def __getattr__(self, name):
attr = getattr(self.underlying, name)
@@ -91,9 +85,16 @@ class ExpressionMixin(
else:
return attr
- def __mul__(self, other) -> 'ExpressionMixin':
- # assert isinstance(other, float)
+ def __getitem__(self, key: tuple[int, int]):
+ return dataclasses.replace(
+ self,
+ underlying=init_get_item_expr(
+ underlying=self.underlying,
+ index=key,
+ ),
+ )
+ def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin':
match other:
case ExpressionBaseMixin():
right = other.underlying
@@ -102,19 +103,15 @@ class ExpressionMixin(
return dataclasses.replace(
self,
- underlying=init_elem_mult_expr(
+ underlying=init_matrix_mult_expr(
left=self.underlying,
right=right,
),
)
- def __rmul__(self, other):
- return self * other
-
- def __neg__(self):
- return self * (-1)
+ def __mul__(self, other) -> 'ExpressionMixin':
+ # assert isinstance(other, float)
- def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin':
match other:
case ExpressionBaseMixin():
right = other.underlying
@@ -123,17 +120,29 @@ class ExpressionMixin(
return dataclasses.replace(
self,
- underlying=init_matrix_mult_expr(
+ underlying=init_elem_mult_expr(
left=self.underlying,
right=right,
),
)
+ def __neg__(self):
+ return self * (-1)
+
+ def __radd__(self, other):
+ return self + other
+
def __rmatmul__(self, other):
other = init_from_sympy_expr(other)
return other @ self
+ def __rmul__(self, other):
+ return self * other
+
+ def __sub__(self, other):
+ return self + other * (-1)
+
def __truediv__(self, other: ExpressionBaseMixin):
match other:
case ExpressionBaseMixin():
@@ -149,24 +158,6 @@ class ExpressionMixin(
),
)
- def __getitem__(self, key: tuple[int, int]):
- return dataclasses.replace(
- self,
- underlying=init_get_item_expr(
- underlying=self.underlying,
- index=key,
- ),
- )
-
- @property
- def T(self) -> 'ExpressionMixin':
- return dataclasses.replace(
- self,
- underlying=init_transpose_expr(
- underlying=self.underlying,
- ),
- )
-
def cache(self) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -291,15 +282,6 @@ class ExpressionMixin(
),
)
- def expressions_in(self, variables: tuple):
- return dataclasses.replace(
- self,
- underlying=init_linear_in_monomials_in(
- underlying=self.underlying,
- variables=variables,
- ),
- )
-
def max(self) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -438,6 +420,15 @@ class ExpressionMixin(
),
)
+ @property
+ def T(self) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_transpose_expr(
+ underlying=self.underlying,
+ ),
+ )
+
def to_constant(self) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -454,12 +445,18 @@ class ExpressionMixin(
),
)
- def truncate(self, variables: tuple, degrees: tuple[int]):
+ def truncate(
+ self,
+ variables: tuple,
+ degrees: tuple[int],
+ inverse: bool = None,
+ ):
return dataclasses.replace(
self,
underlying=init_truncate_expr(
underlying=self.underlying,
variables=variables,
degrees=degrees,
+ inverse=inverse,
),
)