summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-27 11:02:01 +0200
committerNao Pross <np@0hm.ch>2024-05-27 11:02:01 +0200
commit8af3424abafa7c3cd521da79ad93bee368068732 (patch)
treeb64cb57bc989d93ea7e54456c01838e117e95cbe
parentCreate NegationExpr for pretty printing (diff)
downloadpolymatrix-8af3424abafa7c3cd521da79ad93bee368068732.tar.gz
polymatrix-8af3424abafa7c3cd521da79ad93bee368068732.zip
Fix bug in Expression.__sub__, update other operator overloadings
Again, for pretty printing it is annyong to see that -2 * x becomes x * (-2)
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/expression.py86
1 files changed, 41 insertions, 45 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 63bdf6e..3222034 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from dataclassabc import dataclassabc
from typing_extensions import override
-import polymatrix.expression.init
+import polymatrix.expression.init as init
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.mixins.variablemixin import VariableMixin
@@ -43,7 +43,7 @@ class Expression(ExpressionBaseMixin, ABC):
return self.apply(state)[1]
def __add__(self, other: ExpressionBaseMixin) -> Expression:
- return self._binary(polymatrix.expression.init.init_addition_expr, self, other)
+ return self._binary(init.init_addition_expr, self, other)
def __getattr__(self, name):
attr = getattr(self.underlying, name)
@@ -59,7 +59,7 @@ class Expression(ExpressionBaseMixin, ABC):
def __getitem__(self, slice: int | slice | tuple[int | slice, int | slice]) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_slice_expr(
+ underlying=init.init_slice_expr(
underlying=self.underlying,
slice=slice
),
@@ -67,36 +67,32 @@ class Expression(ExpressionBaseMixin, ABC):
def __matmul__(self, other: ExpressionBaseMixin | np.ndarray) -> Expression:
return self._binary(
- polymatrix.expression.init.init_matrix_mult_expr, self, other
+ init.init_matrix_mult_expr, self, other
)
def __mul__(self, other) -> Expression:
- return self._binary(polymatrix.expression.init.init_elem_mult_expr, self, other)
+ return self._binary(init.init_elem_mult_expr, self, other)
def __pow__(self, exponent: Expression | int | float) -> Expression:
- return self._binary(polymatrix.expression.init.init_power_expr, self, exponent)
+ return self._binary(init.init_power_expr, self, exponent)
def __neg__(self):
- return init_expression(polymatrix.expression.init.init_negation_expr(self.underlying))
+ return init_expression(init.init_negation_expr(self.underlying))
def __radd__(self, other):
- return self._binary(polymatrix.expression.init.init_addition_expr, other, self)
+ return self._binary(init.init_addition_expr, other, self)
def __rmatmul__(self, other):
- return self._binary(
- polymatrix.expression.init.init_matrix_mult_expr, other, self
- )
+ return self._binary(init.init_matrix_mult_expr, other, self)
def __rmul__(self, other):
- return self * other
+ return self._binary(init.init_elem_mult_expr, other, self)
def __rsub__(self, other):
- return other + (-self)
+ return self._binary(init.init_subtraction_expr, other, self)
def __sub__(self, other):
- return self._binary(
- polymatrix.expression.init.init_subtraction_expr, other, self
- )
+ return self._binary(init.init_subtraction_expr, self, other)
def __truediv__(self, other):
if not isinstance(other, float | int):
@@ -115,7 +111,7 @@ class Expression(ExpressionBaseMixin, ABC):
return left.copy(underlying=op(left.underlying, right.underlying, stack))
elif isinstance(left, Expression):
- right = polymatrix.expression.init.init_from_expr_or_none(right)
+ right = init.init_from_expr_or_none(right)
if right is None:
return NotImplemented
@@ -124,7 +120,7 @@ class Expression(ExpressionBaseMixin, ABC):
# else right is an Expression
else:
- left = polymatrix.expression.init.init_from_expr_or_none(left)
+ left = init.init_from_expr_or_none(left)
if left is None:
return NotImplemented
@@ -133,7 +129,7 @@ class Expression(ExpressionBaseMixin, ABC):
def cache(self) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_cache_expr(
+ underlying=init.init_cache_expr(
underlying=self.underlying,
),
)
@@ -143,7 +139,7 @@ class Expression(ExpressionBaseMixin, ABC):
degrees = degrees.underlying
return self.copy(
- underlying=polymatrix.expression.init.init_combinations_expr(
+ underlying=init.init_combinations_expr(
expression=self.underlying,
degrees=degrees,
),
@@ -158,14 +154,14 @@ class Expression(ExpressionBaseMixin, ABC):
def determinant(self) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_determinant_expr(
+ underlying=init.init_determinant_expr(
underlying=self.underlying,
),
)
def diag(self):
return self.copy(
- underlying=polymatrix.expression.init.init_diag_expr(
+ underlying=init.init_diag_expr(
underlying=self.underlying,
),
)
@@ -183,7 +179,7 @@ class Expression(ExpressionBaseMixin, ABC):
def divergence(self, variables: tuple) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_divergence_expr(
+ underlying=init.init_divergence_expr(
underlying=self.underlying,
variables=variables,
),
@@ -191,7 +187,7 @@ class Expression(ExpressionBaseMixin, ABC):
def eval(self, variable: tuple, value: tuple[float, ...] | None = None,) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_eval_expr(
+ underlying=init.init_eval_expr(
underlying=self.underlying,
variables=variable,
values=value,
@@ -215,7 +211,7 @@ class Expression(ExpressionBaseMixin, ABC):
# only applies to symmetric matrix
def from_symmetric_matrix(self) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_from_symmetric_matrix_expr(
+ underlying=init.init_from_symmetric_matrix_expr(
underlying=self.underlying,
),
)
@@ -223,7 +219,7 @@ class Expression(ExpressionBaseMixin, ABC):
# only applies to monomials
def half_newton_polytope(self, variables: Expression, filter: Expression | None = None,) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_half_newton_polytope_expr(
+ underlying=init.init_half_newton_polytope_expr(
monomials=self.underlying,
variables=variables,
filter=filter,
@@ -247,7 +243,7 @@ class Expression(ExpressionBaseMixin, ABC):
def linear_matrix_in(self, variable: Expression) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_linear_matrix_in_expr(
+ underlying=init.init_linear_matrix_in_expr(
underlying=self.underlying,
variable=variable,
),
@@ -286,14 +282,14 @@ class Expression(ExpressionBaseMixin, ABC):
def max(self) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_max_expr(
+ underlying=init.init_max_expr(
underlying=self.underlying,
),
)
def parametrize(self, name: str | None = None) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_parametrize_expr(
+ underlying=init.init_parametrize_expr(
underlying=self.underlying,
name=name,
),
@@ -306,7 +302,7 @@ class Expression(ExpressionBaseMixin, ABC):
stack = get_stack_lines()
return self.copy(
- underlying=polymatrix.expression.init.init_quadratic_in_expr(
+ underlying=init.init_quadratic_in_expr(
underlying=self.underlying,
monomials=monomials,
variables=variables,
@@ -319,7 +315,7 @@ class Expression(ExpressionBaseMixin, ABC):
variables: "Expression",
) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_quadratic_monomials_expr(
+ underlying=init.init_quadratic_monomials_expr(
underlying=self.underlying,
variables=variables,
),
@@ -327,7 +323,7 @@ class Expression(ExpressionBaseMixin, ABC):
def reshape(self, n: int, m: int) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_reshape_expr(
+ underlying=init.init_reshape_expr(
underlying=self.underlying,
new_shape=(n, m),
),
@@ -335,7 +331,7 @@ class Expression(ExpressionBaseMixin, ABC):
def rep_mat(self, n: int, m: int) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_rep_mat_expr(
+ underlying=init.init_rep_mat_expr(
underlying=self.underlying,
repetition=(n, m),
),
@@ -351,10 +347,10 @@ class Expression(ExpressionBaseMixin, ABC):
if isinstance(value, Expression):
value = value.underlying
else:
- value = polymatrix.expression.init.init_from_expr(value)
+ value = init.init_from_expr(value)
return self.copy(
- underlying=polymatrix.expression.init.init_set_element_at_expr(
+ underlying=init.init_set_element_at_expr(
underlying=self.underlying,
index=(row, col),
value=value,
@@ -363,14 +359,14 @@ class Expression(ExpressionBaseMixin, ABC):
@property
def shape(self) -> Expression:
- return self.copy(underlying=polymatrix.expression.init.init_shape_expr(self.underlying))
+ return self.copy(underlying=init.init_shape_expr(self.underlying))
# remove?
def squeeze(
self,
) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_squeeze_expr(
+ underlying=init.init_squeeze_expr(
underlying=self.underlying,
),
)
@@ -381,7 +377,7 @@ class Expression(ExpressionBaseMixin, ABC):
monomials: "Expression",
) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_subtract_monomials_expr(
+ underlying=init.init_subtract_monomials_expr(
underlying=self.underlying,
monomials=monomials,
),
@@ -389,21 +385,21 @@ class Expression(ExpressionBaseMixin, ABC):
def sum(self):
return self.copy(
- underlying=polymatrix.expression.init.init_sum_expr(
+ underlying=init.init_sum_expr(
underlying=self.underlying,
),
)
def symmetric(self):
return self.copy(
- underlying=polymatrix.expression.init.init_symmetric_expr(
+ underlying=init.init_symmetric_expr(
underlying=self.underlying,
),
)
def transpose(self) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_transpose_expr(
+ underlying=init.init_transpose_expr(
underlying=self.underlying,
),
)
@@ -414,14 +410,14 @@ class Expression(ExpressionBaseMixin, ABC):
def to_constant(self) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_to_constant_expr(
+ underlying=init.init_to_constant_expr(
underlying=self.underlying,
),
)
def to_symmetric_matrix(self) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_to_symmetric_matrix_expr(
+ underlying=init.init_to_symmetric_matrix_expr(
underlying=self.underlying,
),
)
@@ -429,7 +425,7 @@ class Expression(ExpressionBaseMixin, ABC):
# only applies to variables
def to_sorted_variables(self) -> "Expression":
return self.copy(
- underlying=polymatrix.expression.init.init_to_sorted_variables(
+ underlying=init.init_to_sorted_variables(
underlying=self.underlying,
),
)
@@ -442,7 +438,7 @@ class Expression(ExpressionBaseMixin, ABC):
inverse: bool = None,
):
return self.copy(
- underlying=polymatrix.expression.init.init_truncate_expr(
+ underlying=init.init_truncate_expr(
underlying=self.underlying,
variables=variables,
degrees=degrees,