diff options
author | Nao Pross <np@0hm.ch> | 2024-05-27 11:02:01 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-27 11:02:01 +0200 |
commit | 8af3424abafa7c3cd521da79ad93bee368068732 (patch) | |
tree | b64cb57bc989d93ea7e54456c01838e117e95cbe | |
parent | Create NegationExpr for pretty printing (diff) | |
download | polymatrix-8af3424abafa7c3cd521da79ad93bee368068732.tar.gz polymatrix-8af3424abafa7c3cd521da79ad93bee368068732.zip |
Fix bug in Expression.__sub__, update other operator overloadings
Again, for pretty printing it is annyong to see that -2 * x becomes
x * (-2)
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/expression.py | 86 |
1 files changed, 41 insertions, 45 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 63bdf6e..3222034 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclassabc import dataclassabc from typing_extensions import override -import polymatrix.expression.init +import polymatrix.expression.init as init from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.variablemixin import VariableMixin @@ -43,7 +43,7 @@ class Expression(ExpressionBaseMixin, ABC): return self.apply(state)[1] def __add__(self, other: ExpressionBaseMixin) -> Expression: - return self._binary(polymatrix.expression.init.init_addition_expr, self, other) + return self._binary(init.init_addition_expr, self, other) def __getattr__(self, name): attr = getattr(self.underlying, name) @@ -59,7 +59,7 @@ class Expression(ExpressionBaseMixin, ABC): def __getitem__(self, slice: int | slice | tuple[int | slice, int | slice]) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_slice_expr( + underlying=init.init_slice_expr( underlying=self.underlying, slice=slice ), @@ -67,36 +67,32 @@ class Expression(ExpressionBaseMixin, ABC): def __matmul__(self, other: ExpressionBaseMixin | np.ndarray) -> Expression: return self._binary( - polymatrix.expression.init.init_matrix_mult_expr, self, other + init.init_matrix_mult_expr, self, other ) def __mul__(self, other) -> Expression: - return self._binary(polymatrix.expression.init.init_elem_mult_expr, self, other) + return self._binary(init.init_elem_mult_expr, self, other) def __pow__(self, exponent: Expression | int | float) -> Expression: - return self._binary(polymatrix.expression.init.init_power_expr, self, exponent) + return self._binary(init.init_power_expr, self, exponent) def __neg__(self): - return init_expression(polymatrix.expression.init.init_negation_expr(self.underlying)) + return init_expression(init.init_negation_expr(self.underlying)) def __radd__(self, other): - return self._binary(polymatrix.expression.init.init_addition_expr, other, self) + return self._binary(init.init_addition_expr, other, self) def __rmatmul__(self, other): - return self._binary( - polymatrix.expression.init.init_matrix_mult_expr, other, self - ) + return self._binary(init.init_matrix_mult_expr, other, self) def __rmul__(self, other): - return self * other + return self._binary(init.init_elem_mult_expr, other, self) def __rsub__(self, other): - return other + (-self) + return self._binary(init.init_subtraction_expr, other, self) def __sub__(self, other): - return self._binary( - polymatrix.expression.init.init_subtraction_expr, other, self - ) + return self._binary(init.init_subtraction_expr, self, other) def __truediv__(self, other): if not isinstance(other, float | int): @@ -115,7 +111,7 @@ class Expression(ExpressionBaseMixin, ABC): return left.copy(underlying=op(left.underlying, right.underlying, stack)) elif isinstance(left, Expression): - right = polymatrix.expression.init.init_from_expr_or_none(right) + right = init.init_from_expr_or_none(right) if right is None: return NotImplemented @@ -124,7 +120,7 @@ class Expression(ExpressionBaseMixin, ABC): # else right is an Expression else: - left = polymatrix.expression.init.init_from_expr_or_none(left) + left = init.init_from_expr_or_none(left) if left is None: return NotImplemented @@ -133,7 +129,7 @@ class Expression(ExpressionBaseMixin, ABC): def cache(self) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_cache_expr( + underlying=init.init_cache_expr( underlying=self.underlying, ), ) @@ -143,7 +139,7 @@ class Expression(ExpressionBaseMixin, ABC): degrees = degrees.underlying return self.copy( - underlying=polymatrix.expression.init.init_combinations_expr( + underlying=init.init_combinations_expr( expression=self.underlying, degrees=degrees, ), @@ -158,14 +154,14 @@ class Expression(ExpressionBaseMixin, ABC): def determinant(self) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_determinant_expr( + underlying=init.init_determinant_expr( underlying=self.underlying, ), ) def diag(self): return self.copy( - underlying=polymatrix.expression.init.init_diag_expr( + underlying=init.init_diag_expr( underlying=self.underlying, ), ) @@ -183,7 +179,7 @@ class Expression(ExpressionBaseMixin, ABC): def divergence(self, variables: tuple) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_divergence_expr( + underlying=init.init_divergence_expr( underlying=self.underlying, variables=variables, ), @@ -191,7 +187,7 @@ class Expression(ExpressionBaseMixin, ABC): def eval(self, variable: tuple, value: tuple[float, ...] | None = None,) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_eval_expr( + underlying=init.init_eval_expr( underlying=self.underlying, variables=variable, values=value, @@ -215,7 +211,7 @@ class Expression(ExpressionBaseMixin, ABC): # only applies to symmetric matrix def from_symmetric_matrix(self) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_from_symmetric_matrix_expr( + underlying=init.init_from_symmetric_matrix_expr( underlying=self.underlying, ), ) @@ -223,7 +219,7 @@ class Expression(ExpressionBaseMixin, ABC): # only applies to monomials def half_newton_polytope(self, variables: Expression, filter: Expression | None = None,) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_half_newton_polytope_expr( + underlying=init.init_half_newton_polytope_expr( monomials=self.underlying, variables=variables, filter=filter, @@ -247,7 +243,7 @@ class Expression(ExpressionBaseMixin, ABC): def linear_matrix_in(self, variable: Expression) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_linear_matrix_in_expr( + underlying=init.init_linear_matrix_in_expr( underlying=self.underlying, variable=variable, ), @@ -286,14 +282,14 @@ class Expression(ExpressionBaseMixin, ABC): def max(self) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_max_expr( + underlying=init.init_max_expr( underlying=self.underlying, ), ) def parametrize(self, name: str | None = None) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_parametrize_expr( + underlying=init.init_parametrize_expr( underlying=self.underlying, name=name, ), @@ -306,7 +302,7 @@ class Expression(ExpressionBaseMixin, ABC): stack = get_stack_lines() return self.copy( - underlying=polymatrix.expression.init.init_quadratic_in_expr( + underlying=init.init_quadratic_in_expr( underlying=self.underlying, monomials=monomials, variables=variables, @@ -319,7 +315,7 @@ class Expression(ExpressionBaseMixin, ABC): variables: "Expression", ) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_quadratic_monomials_expr( + underlying=init.init_quadratic_monomials_expr( underlying=self.underlying, variables=variables, ), @@ -327,7 +323,7 @@ class Expression(ExpressionBaseMixin, ABC): def reshape(self, n: int, m: int) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_reshape_expr( + underlying=init.init_reshape_expr( underlying=self.underlying, new_shape=(n, m), ), @@ -335,7 +331,7 @@ class Expression(ExpressionBaseMixin, ABC): def rep_mat(self, n: int, m: int) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_rep_mat_expr( + underlying=init.init_rep_mat_expr( underlying=self.underlying, repetition=(n, m), ), @@ -351,10 +347,10 @@ class Expression(ExpressionBaseMixin, ABC): if isinstance(value, Expression): value = value.underlying else: - value = polymatrix.expression.init.init_from_expr(value) + value = init.init_from_expr(value) return self.copy( - underlying=polymatrix.expression.init.init_set_element_at_expr( + underlying=init.init_set_element_at_expr( underlying=self.underlying, index=(row, col), value=value, @@ -363,14 +359,14 @@ class Expression(ExpressionBaseMixin, ABC): @property def shape(self) -> Expression: - return self.copy(underlying=polymatrix.expression.init.init_shape_expr(self.underlying)) + return self.copy(underlying=init.init_shape_expr(self.underlying)) # remove? def squeeze( self, ) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_squeeze_expr( + underlying=init.init_squeeze_expr( underlying=self.underlying, ), ) @@ -381,7 +377,7 @@ class Expression(ExpressionBaseMixin, ABC): monomials: "Expression", ) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_subtract_monomials_expr( + underlying=init.init_subtract_monomials_expr( underlying=self.underlying, monomials=monomials, ), @@ -389,21 +385,21 @@ class Expression(ExpressionBaseMixin, ABC): def sum(self): return self.copy( - underlying=polymatrix.expression.init.init_sum_expr( + underlying=init.init_sum_expr( underlying=self.underlying, ), ) def symmetric(self): return self.copy( - underlying=polymatrix.expression.init.init_symmetric_expr( + underlying=init.init_symmetric_expr( underlying=self.underlying, ), ) def transpose(self) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_transpose_expr( + underlying=init.init_transpose_expr( underlying=self.underlying, ), ) @@ -414,14 +410,14 @@ class Expression(ExpressionBaseMixin, ABC): def to_constant(self) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_to_constant_expr( + underlying=init.init_to_constant_expr( underlying=self.underlying, ), ) def to_symmetric_matrix(self) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_to_symmetric_matrix_expr( + underlying=init.init_to_symmetric_matrix_expr( underlying=self.underlying, ), ) @@ -429,7 +425,7 @@ class Expression(ExpressionBaseMixin, ABC): # only applies to variables def to_sorted_variables(self) -> "Expression": return self.copy( - underlying=polymatrix.expression.init.init_to_sorted_variables( + underlying=init.init_to_sorted_variables( underlying=self.underlying, ), ) @@ -442,7 +438,7 @@ class Expression(ExpressionBaseMixin, ABC): inverse: bool = None, ): return self.copy( - underlying=polymatrix.expression.init.init_truncate_expr( + underlying=init.init_truncate_expr( underlying=self.underlying, variables=variables, degrees=degrees, |