From f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 11 May 2024 19:32:34 +0200 Subject: Collapse ExpressionState to a single class ExpressionState has the usual mixin, whereby one splits the ABC into smaller parts called mixins, however this class only had one mixin and it makes no sense to have three classes instead of one. This class was simply not complex enough to make it worth maintaining all these files. I do not see a use case that would extend ExpressionState in a way that would introduce enough complexity for the mixin structure to be necessary. But it that were to happen in the future it is easy to reintroduce this structure I just deleted. In addition to reducing complexity this also fixes the inconsistent use of ExpressionStateMixin and ExpressionState across the mixin classes of Expression. --- polymatrix/__init__.py | 4 +- polymatrix/denserepr/from_.py | 2 +- polymatrix/denserepr/impl.py | 2 +- polymatrix/expression/expression.py | 12 +- polymatrix/expression/mixins/additionexprmixin.py | 2 +- polymatrix/expression/mixins/blockdiagexprmixin.py | 2 +- polymatrix/expression/mixins/cacheexprmixin.py | 6 +- .../expression/mixins/combinationsexprmixin.py | 2 +- polymatrix/expression/mixins/degreeexprmixin.py | 2 +- .../expression/mixins/derivativeexprmixin.py | 2 +- polymatrix/expression/mixins/diagexprmixin.py | 6 +- .../expression/mixins/divergenceexprmixin.py | 2 +- polymatrix/expression/mixins/elemmultexprmixin.py | 2 +- polymatrix/expression/mixins/evalexprmixin.py | 2 +- .../expression/mixins/expressionbasemixin.py | 4 +- polymatrix/expression/mixins/eyeexprmixin.py | 2 +- polymatrix/expression/mixins/filterexprmixin.py | 2 +- .../expression/mixins/filterlinearpartexprmixin.py | 2 +- .../expression/mixins/fromnumbersexprmixin.py | 4 +- polymatrix/expression/mixins/fromnumpyexprmixin.py | 4 +- polymatrix/expression/mixins/fromstatemonad.py | 4 +- .../mixins/fromsymmetricmatrixexprmixin.py | 8 +- polymatrix/expression/mixins/fromsympyexprmixin.py | 4 +- polymatrix/expression/mixins/fromtermsexprmixin.py | 6 +- polymatrix/expression/mixins/getitemexprmixin.py | 2 +- .../mixins/halfnewtonpolytopeexprmixin.py | 6 +- polymatrix/expression/mixins/integrateexprmixin.py | 7 +- .../expression/mixins/legendreseriesmixin.py | 2 +- polymatrix/expression/mixins/linearinexprmixin.py | 2 +- .../expression/mixins/linearmatrixinexprmixin.py | 2 +- .../expression/mixins/linearmonomialsexprmixin.py | 6 +- .../expression/mixins/matrixmultexprmixin.py | 2 +- polymatrix/expression/mixins/maxexprmixin.py | 6 +- .../expression/mixins/parametrizeexprmixin.py | 4 +- polymatrix/expression/mixins/powerexprmixin.py | 6 +- polymatrix/expression/mixins/productexprmixin.py | 2 +- .../expression/mixins/quadraticinexprmixin.py | 2 +- .../mixins/quadraticmonomialsexprmixin.py | 6 +- polymatrix/expression/mixins/repmatexprmixin.py | 6 +- polymatrix/expression/mixins/reshapeexprmixin.py | 7 +- .../expression/mixins/setelementatexprmixin.py | 2 +- polymatrix/expression/mixins/squeezeexprmixin.py | 2 +- .../expression/mixins/substituteexprmixin.py | 2 +- .../mixins/subtractmonomialsexprmixin.py | 6 +- polymatrix/expression/mixins/sumexprmixin.py | 6 +- polymatrix/expression/mixins/symmetricexprmixin.py | 4 +- .../expression/mixins/toconstantexprmixin.py | 2 +- .../expression/mixins/tosortedvariablesmixin.py | 2 +- .../mixins/tosymmetricmatrixexprmixin.py | 6 +- polymatrix/expression/mixins/transposeexprmixin.py | 2 +- polymatrix/expression/mixins/truncateexprmixin.py | 2 +- polymatrix/expression/mixins/variablemixin.py | 5 +- polymatrix/expression/mixins/vstackexprmixin.py | 2 +- polymatrix/expression/to.py | 2 +- .../expression/utils/getderivativemonomials.py | 2 +- polymatrix/expression/utils/getmonomialindices.py | 3 +- polymatrix/expression/utils/getvariableindices.py | 2 +- polymatrix/expressionstate.py | 181 +++++++++++++++++++++ polymatrix/expressionstate/__init__.py | 0 polymatrix/expressionstate/abc.py | 6 - polymatrix/expressionstate/impl.py | 10 -- polymatrix/expressionstate/init.py | 9 - polymatrix/expressionstate/mixins.py | 176 -------------------- 63 files changed, 282 insertions(+), 308 deletions(-) create mode 100644 polymatrix/expressionstate.py delete mode 100644 polymatrix/expressionstate/__init__.py delete mode 100644 polymatrix/expressionstate/abc.py delete mode 100644 polymatrix/expressionstate/impl.py delete mode 100644 polymatrix/expressionstate/init.py delete mode 100644 polymatrix/expressionstate/mixins.py diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 94da22a..62fc4a6 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -1,5 +1,5 @@ -from polymatrix.expressionstate.abc import ExpressionState as internal_ExpressionState -from polymatrix.expressionstate.init import ( +from polymatrix.expressionstate import ( + ExpressionState as internal_ExpressionState, init_expression_state as internal_init_expression_state, ) diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index 1eda6fe..52040a1 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -3,7 +3,7 @@ import numpy as np from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.getvariableindices import ( diff --git a/polymatrix/denserepr/impl.py b/polymatrix/denserepr/impl.py index 9f53c8b..4d0835e 100644 --- a/polymatrix/denserepr/impl.py +++ b/polymatrix/denserepr/impl.py @@ -4,7 +4,7 @@ import typing import numpy as np import scipy.sparse -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 415f16f..2e7cf11 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -8,17 +8,15 @@ from abc import ABC, abstractmethod from dataclassabc import dataclassabc from typing_extensions import override - -from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.variable.abc import Variable import polymatrix.expression.init -from polymatrix.utils.getstacklines import get_stack_lines - -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.utils.getstacklines import get_stack_lines +from polymatrix.variable.abc import Variable + from polymatrix.expression.op import ( diff, integrate, diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index a51bae5..4b0a9c9 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index dece9c3..45e5e1b 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -5,7 +5,7 @@ import dataclassabc from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index 7947a74..c8a0a23 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -4,7 +4,7 @@ import dataclasses from polymatrix.polymatrix.mixins import PolyMatrixAsDictMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -18,8 +18,8 @@ class CacheExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: # FIXME: return type + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: # FIXME: return type if self in state.cache: return state, state.cache[self] diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index 121d71c..abbb53d 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -5,7 +5,7 @@ from typing import Iterable from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/degreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py index 522af7c..331c350 100644 --- a/polymatrix/expression/mixins/degreeexprmixin.py +++ b/polymatrix/expression/mixins/degreeexprmixin.py @@ -4,7 +4,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import MatrixIndex, PolyMatrixDict, PolyDict, MonomialIndex from polymatrix.utils.getstacklines import FrameSummary diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 31a1e31..dab1b41 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -4,7 +4,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py index d181543..b7666f0 100644 --- a/polymatrix/expression/mixins/diagexprmixin.py +++ b/polymatrix/expression/mixins/diagexprmixin.py @@ -2,7 +2,7 @@ import abc import dataclassabc from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -24,8 +24,8 @@ class DiagExprMixin(ExpressionBaseMixin): # FIXME: typing def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) # Vector to diagonal matrix diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 3477ee1..074caa8 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -5,7 +5,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index e914fd5..6ef97a3 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.index import MonomialIndex from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 881ee83..103c556 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -5,7 +5,7 @@ import math from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py index 7f58339..dc8f699 100644 --- a/polymatrix/expression/mixins/expressionbasemixin.py +++ b/polymatrix/expression/mixins/expressionbasemixin.py @@ -1,9 +1,9 @@ import abc -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin class ExpressionBaseMixin(abc.ABC): @abc.abstractmethod - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: ... + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: ... diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index 6a65585..74dd930 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyDict, MonomialIndex diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index c8fa427..3cb6e5f 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class FilterExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index 046f191..6edeb78 100644 --- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # is this class needed? diff --git a/polymatrix/expression/mixins/fromnumbersexprmixin.py b/polymatrix/expression/mixins/fromnumbersexprmixin.py index e34956b..d1c3ccf 100644 --- a/polymatrix/expression/mixins/fromnumbersexprmixin.py +++ b/polymatrix/expression/mixins/fromnumbersexprmixin.py @@ -2,7 +2,7 @@ from abc import abstractmethod from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix @@ -24,7 +24,7 @@ class FromNumbersExprMixin(ExpressionBaseMixin): """ The matrix of numbers in row major order. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: p = PolyMatrixDict.empty() for r, row in enumerate(self.data): diff --git a/polymatrix/expression/mixins/fromnumpyexprmixin.py b/polymatrix/expression/mixins/fromnumpyexprmixin.py index cdf0aa6..cee5cc0 100644 --- a/polymatrix/expression/mixins/fromnumpyexprmixin.py +++ b/polymatrix/expression/mixins/fromnumpyexprmixin.py @@ -5,7 +5,7 @@ from typing_extensions import override, cast from numpy.typing import NDArray from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix @@ -26,7 +26,7 @@ class FromNumpyExprMixin(ExpressionBaseMixin): """ The Numpy array. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: p = PolyMatrixDict.empty() if len(self.data.shape) > 2: diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py index 254d125..b7d9ea7 100644 --- a/polymatrix/expression/mixins/fromstatemonad.py +++ b/polymatrix/expression/mixins/fromstatemonad.py @@ -3,7 +3,7 @@ from typing_extensions import override from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.statemonad.abc import StateMonad @@ -23,7 +23,7 @@ class FromStateMonadMixin(ExpressionBaseMixin): """ The state monad object. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, expr = self.monad.apply(state) # Case when monad wraps function # f: ExpressionState -> (State, Expression) diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py index c7d90a1..55a64e5 100644 --- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.expression.utils.getvariableindices import get_variable_indices_ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -21,8 +21,8 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) @@ -43,4 +43,4 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): shape=(var_index, 1), ) - return state, poly_matrix \ No newline at end of file + return state, poly_matrix diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index d178415..6d5446a 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -5,7 +5,7 @@ from abc import abstractmethod from typing_extensions import override, cast from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix @@ -36,7 +36,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): """ The sympy objects. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: # Unpack if it is a sympy matrix if isinstance(self.data, sympy.Matrix): diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 463f41c..8d340fd 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -2,7 +2,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import MonomialData @@ -26,8 +26,8 @@ class FromPolynomialDataExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: data = {coord: dict(polynomial) for coord, polynomial in self.data} poly_matrix = init_poly_matrix( diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index affe4ec..6fac468 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class GetItemExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py index 52e5814..01b6c47 100644 --- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py +++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py @@ -9,7 +9,7 @@ from polymatrix.expression.utils.getvariableindices import ( from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -29,8 +29,8 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, sos_monomials = get_monomial_indices(state, self.monomials) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/integrateexprmixin.py b/polymatrix/expression/mixins/integrateexprmixin.py index 38cf522..5b1db7b 100644 --- a/polymatrix/expression/mixins/integrateexprmixin.py +++ b/polymatrix/expression/mixins/integrateexprmixin.py @@ -4,7 +4,7 @@ import itertools from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -41,10 +41,7 @@ class IntegrateExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) state, variables = get_variable_indices_from_variable(state, self.variables) diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py index bec4c6c..d2c8372 100644 --- a/polymatrix/expression/mixins/legendreseriesmixin.py +++ b/polymatrix/expression/mixins/legendreseriesmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 8354754..d9e8032 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index 8239947..7ddc546 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -5,7 +5,7 @@ from numpy import var from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index c4a5df7..076cbb1 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -40,8 +40,8 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index 2d86619..10f5737 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -4,7 +4,7 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.tooperatorexception import to_operator_exception diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index 650dc68..60c1832 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -2,7 +2,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -15,8 +15,8 @@ class MaxExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) poly_matrix_data = {} diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index ca49411..5ffecf3 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -4,7 +4,7 @@ from abc import abstractmethod from itertools import product from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex, VariableIndex from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -36,7 +36,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin): def name(self) -> str: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) nrows, ncols = underlying.shape diff --git a/polymatrix/expression/mixins/powerexprmixin.py b/polymatrix/expression/mixins/powerexprmixin.py index 4ca1b0b..6d6aae4 100644 --- a/polymatrix/expression/mixins/powerexprmixin.py +++ b/polymatrix/expression/mixins/powerexprmixin.py @@ -7,7 +7,7 @@ from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -30,10 +30,10 @@ class PowerExprMixin(ExpressionBaseMixin): @staticmethod def power( - state: ExpressionStateMixin, + state: ExpressionState, left: ExpressionBaseMixin, right: ExpressionBaseMixin | int | float - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + ) -> tuple[ExpressionState, PolyMatrixMixin]: """ Compute the expression ```left ** right```. """ exponent: int | None = None diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py index 12ea4f2..9db1b9c 100644 --- a/polymatrix/expression/mixins/productexprmixin.py +++ b/polymatrix/expression/mixins/productexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index 144fbab..cba2844 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index 2717e7e..3ce3b85 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -41,8 +41,8 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index e3c6498..7ef01ae 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -2,7 +2,7 @@ import abc import dataclassabc from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -19,8 +19,8 @@ class RepMatExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) # FIXME: move to polymatrix module diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 9f43c04..23f295b 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -6,7 +6,7 @@ import dataclassabc import numpy as np from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -25,10 +25,11 @@ class ReshapeExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) + # TODO: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class ReshapePolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 61db750..35c99d8 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class SetElementAtExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py index 67ade46..785647c 100644 --- a/polymatrix/expression/mixins/squeezeexprmixin.py +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # remove? diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index d2ab740..a3e76b1 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -7,7 +7,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index 6bac539..3efb656 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.polymatrix.utils.sortmonomials import sort_monomials @@ -29,8 +29,8 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = get_monomial_indices(state, self.underlying) state, sub_monomials = get_monomial_indices(state, self.monomials) diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py index 4ade032..af722a4 100644 --- a/polymatrix/expression/mixins/sumexprmixin.py +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -4,7 +4,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -22,8 +22,8 @@ class SumExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) poly_matrix_data = collections.defaultdict(dict) diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index ea9d16f..dd77bef 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -1,15 +1,13 @@ import abc import collections -import itertools import dataclassabc -import typing from polymatrix.polymatrix.index import PolyDict from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class SymmetricExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py index 9c8b9b7..51c2e56 100644 --- a/polymatrix/expression/mixins/toconstantexprmixin.py +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class ToConstantExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py index 4dfa518..865fdf5 100644 --- a/polymatrix/expression/mixins/tosortedvariablesmixin.py +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -6,7 +6,7 @@ from polymatrix.expression.utils.getvariableindices import ( from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # to be deleted? diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py index 3c50f76..65dfffd 100644 --- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.expression.utils.getvariableindices import ( ) from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -23,8 +23,8 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) assert underlying.shape[1] == 1 diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index eaab11a..16139ba 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.index import PolyDict from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class TransposeExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index 3171221..629dd45 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py index 750e320..fe8ad80 100644 --- a/polymatrix/expression/mixins/variablemixin.py +++ b/polymatrix/expression/mixins/variablemixin.py @@ -4,9 +4,8 @@ import typing import itertools from typing_extensions import override -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex from polymatrix.variable.abc import Variable @@ -16,7 +15,7 @@ class VariableMixin(ExpressionBaseMixin, Variable): """ Underlying object for VariableExpression """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state = state.register(self) indices = state.get_indices(self) p = PolyMatrixDict() diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index 8f17b2f..249bb26 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.polymatrix.index import PolyDict from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class VStackExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index c24b1d2..4509d1e 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -3,7 +3,7 @@ import sympy import numpy as np from polymatrix.expression.expression import Expression -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py index cf2ae92..de783d9 100644 --- a/polymatrix/expression/utils/getderivativemonomials.py +++ b/polymatrix/expression/utils/getderivativemonomials.py @@ -2,7 +2,7 @@ import collections import dataclasses import itertools -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolynomialData diff --git a/polymatrix/expression/utils/getmonomialindices.py b/polymatrix/expression/utils/getmonomialindices.py index 9145002..6470b5e 100644 --- a/polymatrix/expression/utils/getmonomialindices.py +++ b/polymatrix/expression/utils/getmonomialindices.py @@ -1,7 +1,8 @@ -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +# TODO: mark as deprecated, explain replacement def get_monomial_indices( state: ExpressionState, expression: ExpressionBaseMixin, diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index d87f82e..13b6aa0 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,7 +1,7 @@ import itertools import typing -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.utils.deprecation import deprecated diff --git a/polymatrix/expressionstate.py b/polymatrix/expressionstate.py new file mode 100644 index 0000000..931f282 --- /dev/null +++ b/polymatrix/expressionstate.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, NamedTuple, Iterable +from math import prod +from dataclassabc import dataclassabc +from dataclasses import replace + +from polymatrix.variable.abc import Variable +from polymatrix.utils.deprecation import deprecated +from polymatrix.polymatrix.index import MonomialIndex, VariableIndex + +from polymatrix.statemonad.mixins import StateCacheMixin + +# TODO: move to typing submodule +class IndexRange(NamedTuple): + start: int + end: int + + def __lt__(self, other): + return self.start < other.start + + +@dataclassabc(frozen=True) +class ExpressionState(StateCacheMixin): + n_variables: int + """ Number of polynomial variables """ + + indices: dict[Variable, IndexRange] + """ Map from variable objects to their indices. """ + + cache: dict + """ Cache for StateCacheMixin """ + + def index(self, var: Variable) -> tuple[ExpressionState, IndexRange]: + """ Index a variable and get its index range. """ + if not isinstance(var, Variable): + raise ValueError("State can only index object of type Variable!") + + for v, irange in self.indices.items(): + # Check if already in there + if v == var: + return self, irange + + # Check that there is not another variable with the same name + if v.name == var.name: + raise ValueError("Variable must have unique names! " + f"There is already a variable named {var.name} " + f"with shape {v.shape}") + + # If not save new index + size = prod(var.shape) + index = IndexRange(start=self.n_variables, end=self.n_variables + size) + + return replace( + self, + n_variables=self.n_variables + size, + indices=self.indices | {var: index} + ), index + + def register(self, var: Variable) -> ExpressionStateMixin: + """ + Create an index for a variable, but does not return the index. If you + want the index range use :py:meth:`index` + """ + state, _ = self.index(var) + return state + + def get_indices(self, var: Variable) -> Iterable[int]: + """ + Get all indices associated to a variable. + + When a variable is not a scalar multiple indices will be associated to + the variables, one for each entry. + + See also :py:meth:`get_variable_indices`, :py:meth:`get_monomial_indices`. + """ + if var not in self.indices: + raise IndexError(f"There is no variable {var} in this state object.") + + yield from range(*self.indices[var]) + + def get_indices_as_variable_index(self, var: Variable) -> Iterable[VariableIndex]: + """ + Get all indices associated to a variable, wrapped in a `VariableIndex`. + + See also :py:meth:`get_indices`, + :py:class:`polymatrix.polymatrix.index.VariableIndex`. + """ + yield from (VariableIndex(index=i, power=1) + for i in self.get_indices(var)) + + def get_indices_as_monomial_index(self, var: Variable) -> Iterable[MonomialIndex]: + """ + Get all indices associated to a variable, wrapped in a `MonomialIndex`. + + See also :py:meth:`get_indices`, + :py:class:`polymatrix.polymatrix.index.MonomialIndex`. + """ + yield from (MonomialIndex((v,)) + for v in self.get_indices_as_variable_index(var)) + + def get_variable(self, index: int) -> Variable: + """ Get the variable object from its index. """ + for variable, (start, end) in self.indices.items(): + if start <= index < end: + return variable + + raise IndexError(f"There is no variable with index {index}.") + + def get_variable_from_variable_index(self, var: VariableIndex) -> Variable: + """ Get the variable object from the index contained in a `VariableIndex` """ + return self.get_variable(var.index) + + def get_variables_from_monomial_index(self, monomial: MonomialIndex) -> Iterable[Variable]: + """ Get all variable objects from the indices contained in a `MonomialIndex` """ + for v in monomial: + # FIXME: non-scalar variable will be yielded multiple times + yield self.get_variable_from_variable_index(v) + + def get_variable_from_name_or(self, name: str, if_not_present: Any) -> Variable | Any: + """ + Get a variable object given its name, or if there is no variable with + the given name return what is passed in the `if_not_present` argument. + """ + for v in self.indices.keys(): + if v.name == name: + return name + + return if_not_present + + def get_variable_from_name(self, name: str) -> Variable: + """ + Get a variable object given its name, raises KeyError if there is no + variable with the given name. + """ + if v := self.get_variable_from_name_or(name, False): + return v + + raise KeyError(f"There is no variable named {name}") + + def get_name(self, index: int) -> str: + """ Get the name of a variable given its index. """ + for variable, (start, end) in self.indices.items(): + if start <= index < end: + # Variable is not scalar + if end - start > 1: + return f"{variable.name}_{index - start}" + + return variable.name + + raise IndexError(f"There is no variable with index {index}.") + + # -- Old API --- + + @property + @deprecated("replaced by n_variables") + def n_param(self) -> int: + return self.n_variables + + @property + @deprecated("replaced by indices") + def offset_dict(self): + return self.indices + + @property + @deprecated("Support for auxillary equations was removed") + def auxillary_equations(self): + return {} + + @deprecated("replaced by get_variable") + def get_key_from_offset(self, index: int) -> Variable: + return self.get_variable(index) + + +def init_expression_state(n_variables: int = 0, indices: dict[Variable, IndexRange] = {}): + return ExpressionState( + n_variables=n_variables, + indices=indices, + cache={}, + ) diff --git a/polymatrix/expressionstate/__init__.py b/polymatrix/expressionstate/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/polymatrix/expressionstate/abc.py b/polymatrix/expressionstate/abc.py deleted file mode 100644 index bdc4484..0000000 --- a/polymatrix/expressionstate/abc.py +++ /dev/null @@ -1,6 +0,0 @@ -from polymatrix.expressionstate.mixins import ExpressionStateMixin - - -# NP: "state" of an expression that maps indices to variable / parameter objects -class ExpressionState(ExpressionStateMixin): - pass diff --git a/polymatrix/expressionstate/impl.py b/polymatrix/expressionstate/impl.py deleted file mode 100644 index c243a26..0000000 --- a/polymatrix/expressionstate/impl.py +++ /dev/null @@ -1,10 +0,0 @@ -import dataclassabc - -from polymatrix.expressionstate.abc import ExpressionState - - -@dataclassabc.dataclassabc(frozen=True) -class ExpressionStateImpl(ExpressionState): - n_variables: int - indices: dict - cache: dict diff --git a/polymatrix/expressionstate/init.py b/polymatrix/expressionstate/init.py deleted file mode 100644 index d7c7d25..0000000 --- a/polymatrix/expressionstate/init.py +++ /dev/null @@ -1,9 +0,0 @@ -from polymatrix.expressionstate.impl import ExpressionStateImpl - - -def init_expression_state(n_param: int = 0, offset_dict: dict = {}): - return ExpressionStateImpl( - n_variables=n_param, - indices={}, - cache={}, - ) diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py deleted file mode 100644 index 8d1145f..0000000 --- a/polymatrix/expressionstate/mixins.py +++ /dev/null @@ -1,176 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from abc import abstractmethod -from typing import NamedTuple, Iterable -from math import prod -import dataclasses - -from polymatrix.variable.abc import Variable -from polymatrix.utils.deprecation import deprecated -from polymatrix.polymatrix.index import MonomialIndex, VariableIndex - -from polymatrix.statemonad.mixins import StateCacheMixin - -# TODO: move to typing submodule -class IndexRange(NamedTuple): - start: int - end: int - - def __lt__(self, other): - return self.start < other.start - - -class ExpressionStateMixin( - StateCacheMixin, -): - @property - @abstractmethod - def n_variables(self) -> int: - """ Number of polynomial variables """ - - @property - @abstractmethod - def indices(self) -> dict[Variable, IndexRange]: - """ Map from variable objects to their indices. """ - - def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: - """ Index a variable and get its index range. """ - if not isinstance(var, Variable): - raise ValueError("State can only index object of type Variable!") - - for v, irange in self.indices.items(): - # Check if already in there - if v == var: - return self, irange - - # Check that there is not another variable with the same name - if v.name == var.name: - raise ValueError("Variable must have unique names! " - f"There is already a variable named {var.name} " - f"with shape {v.shape}") - - # If not save new index - size = prod(var.shape) - index = IndexRange(start=self.n_variables, end=self.n_variables + size) - - return dataclasses.replace( - self, - n_variables=self.n_variables + size, - indices=self.indices | {var: index} - ), index - - def register(self, var: Variable) -> ExpressionStateMixin: - """ - Create an index for a variable, but does not return the index. If you - want the index range use :py:meth:`index` - """ - state, _ = self.index(var) - return state - - def get_indices(self, var: Variable) -> Iterable[int]: - """ - Get all indices associated to a variable. - - When a variable is not a scalar multiple indices will be associated to - the variables, one for each entry. - - See also :py:meth:`get_variable_indices`, :py:meth:`get_monomial_indices`. - """ - if var not in self.indices: - raise IndexError(f"There is no variable {var} in this state object.") - - yield from range(*self.indices[var]) - - def get_indices_as_variable_index(self, var: Variable) -> Iterable[VariableIndex]: - """ - Get all indices associated to a variable, wrapped in a `VariableIndex`. - - See also :py:meth:`get_indices`, - :py:class:`polymatrix.polymatrix.index.VariableIndex`. - """ - yield from (VariableIndex(index=i, power=1) - for i in self.get_indices(var)) - - def get_indices_as_monomial_index(self, var: Variable) -> Iterable[MonomialIndex]: - """ - Get all indices associated to a variable, wrapped in a `MonomialIndex`. - - See also :py:meth:`get_indices`, - :py:class:`polymatrix.polymatrix.index.MonomialIndex`. - """ - yield from (MonomialIndex((v,)) - for v in self.get_indices_as_variable_index(var)) - - def get_variable(self, index: int) -> Variable: - """ Get the variable object from its index. """ - for variable, (start, end) in self.indices.items(): - if start <= index < end: - return variable - - raise IndexError(f"There is no variable with index {index}.") - - def get_variable_from_variable_index(self, var: VariableIndex) -> Variable: - """ Get the variable object from the index contained in a `VariableIndex` """ - return self.get_variable(var.index) - - def get_variables_from_monomial_index(self, monomial: MonomialIndex) -> Iterable[Variable]: - """ Get all variable objects from the indices contained in a `MonomialIndex` """ - for v in monomial: - # FIXME: non-scalar variable will be yielded multiple times - yield self.get_variable_from_variable_index(v) - - def get_variable_from_name_or(self, name: str, if_not_present: Any) -> Variable | Any: - """ - Get a variable object given its name, or if there is no variable with - the given name return what is passed in the `if_not_present` argument. - """ - for v in self.indices.keys(): - if v.name == name: - return name - - return if_not_present - - def get_variable_from_name(self, name: str) -> Variable: - """ - Get a variable object given its name, raises KeyError if there is no - variable with the given name. - """ - if v := self.get_variable_from_name_or(name, False): - return v - - raise KeyError(f"There is no variable named {name}") - - def get_name(self, index: int) -> str: - """ Get the name of a variable given its index. """ - for variable, (start, end) in self.indices.items(): - if start <= index < end: - # Variable is not scalar - if end - start > 1: - return f"{variable.name}_{index - start}" - - return variable.name - - raise IndexError(f"There is no variable with index {index}.") - - # -- Old API --- - - @property - @deprecated("replaced by n_variables") - def n_param(self) -> int: - return self.n_variables - - @property - @deprecated("replaced by indices") - def offset_dict(self): - return self.indices - - @property - @deprecated("Support for auxillary equations was removed") - def auxillary_equations(self): - return {} - - @deprecated("replaced by get_variable") - def get_key_from_offset(self, index: int) -> Variable: - return self.get_variable(index) - -- cgit v1.2.1