diff options
-rw-r--r-- | polymatrix/__init__.py | 1 | ||||
-rw-r--r-- | polymatrix/expression/impl/tosortedvariablesimpl.py | 8 | ||||
-rw-r--r-- | polymatrix/expression/init/inittosortedvariables.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/mixins/expressionmixin.py | 13 | ||||
-rw-r--r-- | polymatrix/expression/mixins/tosortedvariablesmixin.py | 33 | ||||
-rw-r--r-- | polymatrix/expression/tosortedvariables.py | 4 |
6 files changed, 67 insertions, 2 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index ad11c70..de3f9f3 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -483,6 +483,7 @@ def to_matrix_repr( continue for monomial, value in underlying_terms.items(): + def gen_new_monomial(): for var, count in monomial: try: diff --git a/polymatrix/expression/impl/tosortedvariablesimpl.py b/polymatrix/expression/impl/tosortedvariablesimpl.py new file mode 100644 index 0000000..249e55e --- /dev/null +++ b/polymatrix/expression/impl/tosortedvariablesimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.tosortedvariables import ToSortedVariables + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class ToSortedVariablesImpl(ToSortedVariables): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/init/inittosortedvariables.py b/polymatrix/expression/init/inittosortedvariables.py new file mode 100644 index 0000000..5d669c9 --- /dev/null +++ b/polymatrix/expression/init/inittosortedvariables.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.tosortedvariablesimpl import ToSortedVariablesImpl + + +def init_to_sorted_variables( + underlying: ExpressionBaseMixin, +): + return ToSortedVariablesImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index b828d6a..622db34 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -36,6 +36,7 @@ from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr from polymatrix.expression.init.initdiagexpr import init_diag_expr +from polymatrix.expression.init.inittosortedvariables import init_to_sorted_variables from polymatrix.expression.init.inittransposeexpr import init_transpose_expr from polymatrix.expression.init.inittruncateexpr import init_truncate_expr @@ -454,10 +455,18 @@ class ExpressionMixin( ), ) - def to_quadratic(self) -> 'ExpressionMixin': + # def to_quadratic(self) -> 'ExpressionMixin': + # return dataclasses.replace( + # self, + # underlying=init_to_quadratic_expr( + # underlying=self.underlying, + # ), + # ) + + def to_sorted_variables(self) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_to_quadratic_expr( + underlying=init_to_sorted_variables( underlying=self.underlying, ), ) diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py new file mode 100644 index 0000000..490604d --- /dev/null +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -0,0 +1,33 @@ + +import abc +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable + +from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.polymatrix.polymatrix import PolyMatrix +from polymatrix.expressionstate.expressionstate import ExpressionState + + +class ToSortedVariablesMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, variable_indices = get_variable_indices_from_variable(state, self.underlying) + + def gen_sorted_vector(): + for row, index in enumerate(sorted(variable_indices)): + yield (row, 1), {(index, 1): 1} + + poly_matrix = init_poly_matrix( + terms=dict(gen_sorted_vector()), + shape=(len(variable_indices), 1), + ) + + return state, poly_matrix diff --git a/polymatrix/expression/tosortedvariables.py b/polymatrix/expression/tosortedvariables.py new file mode 100644 index 0000000..89ca315 --- /dev/null +++ b/polymatrix/expression/tosortedvariables.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesMixin + +class ToSortedVariables(ToSortedVariablesMixin): + pass |