diff options
62 files changed, 122 insertions, 148 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 94da22a..62fc4a6 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -1,5 +1,5 @@ -from polymatrix.expressionstate.abc import ExpressionState as internal_ExpressionState -from polymatrix.expressionstate.init import ( +from polymatrix.expressionstate import ( + ExpressionState as internal_ExpressionState, init_expression_state as internal_init_expression_state, ) diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index 1eda6fe..52040a1 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -3,7 +3,7 @@ import numpy as np from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.getvariableindices import ( diff --git a/polymatrix/denserepr/impl.py b/polymatrix/denserepr/impl.py index 9f53c8b..4d0835e 100644 --- a/polymatrix/denserepr/impl.py +++ b/polymatrix/denserepr/impl.py @@ -4,7 +4,7 @@ import typing import numpy as np import scipy.sparse -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 415f16f..2e7cf11 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -8,17 +8,15 @@ from abc import ABC, abstractmethod from dataclassabc import dataclassabc from typing_extensions import override - -from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.variable.abc import Variable import polymatrix.expression.init -from polymatrix.utils.getstacklines import get_stack_lines - -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.utils.getstacklines import get_stack_lines +from polymatrix.variable.abc import Variable + from polymatrix.expression.op import ( diff, integrate, diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index a51bae5..4b0a9c9 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index dece9c3..45e5e1b 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -5,7 +5,7 @@ import dataclassabc from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index 7947a74..c8a0a23 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -4,7 +4,7 @@ import dataclasses from polymatrix.polymatrix.mixins import PolyMatrixAsDictMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -18,8 +18,8 @@ class CacheExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: # FIXME: return type + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: # FIXME: return type if self in state.cache: return state, state.cache[self] diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index 121d71c..abbb53d 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -5,7 +5,7 @@ from typing import Iterable from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/degreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py index 522af7c..331c350 100644 --- a/polymatrix/expression/mixins/degreeexprmixin.py +++ b/polymatrix/expression/mixins/degreeexprmixin.py @@ -4,7 +4,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import MatrixIndex, PolyMatrixDict, PolyDict, MonomialIndex from polymatrix.utils.getstacklines import FrameSummary diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 31a1e31..dab1b41 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -4,7 +4,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py index d181543..b7666f0 100644 --- a/polymatrix/expression/mixins/diagexprmixin.py +++ b/polymatrix/expression/mixins/diagexprmixin.py @@ -2,7 +2,7 @@ import abc import dataclassabc from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -24,8 +24,8 @@ class DiagExprMixin(ExpressionBaseMixin): # FIXME: typing def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) # Vector to diagonal matrix diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 3477ee1..074caa8 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -5,7 +5,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index e914fd5..6ef97a3 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.index import MonomialIndex from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 881ee83..103c556 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -5,7 +5,7 @@ import math from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py index 7f58339..dc8f699 100644 --- a/polymatrix/expression/mixins/expressionbasemixin.py +++ b/polymatrix/expression/mixins/expressionbasemixin.py @@ -1,9 +1,9 @@ import abc -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin class ExpressionBaseMixin(abc.ABC): @abc.abstractmethod - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: ... + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: ... diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index 6a65585..74dd930 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyDict, MonomialIndex diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index c8fa427..3cb6e5f 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class FilterExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index 046f191..6edeb78 100644 --- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # is this class needed? diff --git a/polymatrix/expression/mixins/fromnumbersexprmixin.py b/polymatrix/expression/mixins/fromnumbersexprmixin.py index e34956b..d1c3ccf 100644 --- a/polymatrix/expression/mixins/fromnumbersexprmixin.py +++ b/polymatrix/expression/mixins/fromnumbersexprmixin.py @@ -2,7 +2,7 @@ from abc import abstractmethod from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix @@ -24,7 +24,7 @@ class FromNumbersExprMixin(ExpressionBaseMixin): """ The matrix of numbers in row major order. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: p = PolyMatrixDict.empty() for r, row in enumerate(self.data): diff --git a/polymatrix/expression/mixins/fromnumpyexprmixin.py b/polymatrix/expression/mixins/fromnumpyexprmixin.py index cdf0aa6..cee5cc0 100644 --- a/polymatrix/expression/mixins/fromnumpyexprmixin.py +++ b/polymatrix/expression/mixins/fromnumpyexprmixin.py @@ -5,7 +5,7 @@ from typing_extensions import override, cast from numpy.typing import NDArray from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix @@ -26,7 +26,7 @@ class FromNumpyExprMixin(ExpressionBaseMixin): """ The Numpy array. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: p = PolyMatrixDict.empty() if len(self.data.shape) > 2: diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py index 254d125..b7d9ea7 100644 --- a/polymatrix/expression/mixins/fromstatemonad.py +++ b/polymatrix/expression/mixins/fromstatemonad.py @@ -3,7 +3,7 @@ from typing_extensions import override from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.statemonad.abc import StateMonad @@ -23,7 +23,7 @@ class FromStateMonadMixin(ExpressionBaseMixin): """ The state monad object. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, expr = self.monad.apply(state) # Case when monad wraps function # f: ExpressionState -> (State, Expression) diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py index c7d90a1..55a64e5 100644 --- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.expression.utils.getvariableindices import get_variable_indices_ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -21,8 +21,8 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) @@ -43,4 +43,4 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): shape=(var_index, 1), ) - return state, poly_matrix
\ No newline at end of file + return state, poly_matrix diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index d178415..6d5446a 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -5,7 +5,7 @@ from abc import abstractmethod from typing_extensions import override, cast from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix @@ -36,7 +36,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): """ The sympy objects. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: # Unpack if it is a sympy matrix if isinstance(self.data, sympy.Matrix): diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 463f41c..8d340fd 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -2,7 +2,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import MonomialData @@ -26,8 +26,8 @@ class FromPolynomialDataExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: data = {coord: dict(polynomial) for coord, polynomial in self.data} poly_matrix = init_poly_matrix( diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index affe4ec..6fac468 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class GetItemExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py index 52e5814..01b6c47 100644 --- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py +++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py @@ -9,7 +9,7 @@ from polymatrix.expression.utils.getvariableindices import ( from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -29,8 +29,8 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, sos_monomials = get_monomial_indices(state, self.monomials) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/integrateexprmixin.py b/polymatrix/expression/mixins/integrateexprmixin.py index 38cf522..5b1db7b 100644 --- a/polymatrix/expression/mixins/integrateexprmixin.py +++ b/polymatrix/expression/mixins/integrateexprmixin.py @@ -4,7 +4,7 @@ import itertools from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -41,10 +41,7 @@ class IntegrateExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) state, variables = get_variable_indices_from_variable(state, self.variables) diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py index bec4c6c..d2c8372 100644 --- a/polymatrix/expression/mixins/legendreseriesmixin.py +++ b/polymatrix/expression/mixins/legendreseriesmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 8354754..d9e8032 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index 8239947..7ddc546 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -5,7 +5,7 @@ from numpy import var from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index c4a5df7..076cbb1 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -40,8 +40,8 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index 2d86619..10f5737 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -4,7 +4,7 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.tooperatorexception import to_operator_exception diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index 650dc68..60c1832 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -2,7 +2,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -15,8 +15,8 @@ class MaxExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) poly_matrix_data = {} diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index ca49411..5ffecf3 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -4,7 +4,7 @@ from abc import abstractmethod from itertools import product from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex, VariableIndex from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -36,7 +36,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin): def name(self) -> str: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) nrows, ncols = underlying.shape diff --git a/polymatrix/expression/mixins/powerexprmixin.py b/polymatrix/expression/mixins/powerexprmixin.py index 4ca1b0b..6d6aae4 100644 --- a/polymatrix/expression/mixins/powerexprmixin.py +++ b/polymatrix/expression/mixins/powerexprmixin.py @@ -7,7 +7,7 @@ from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -30,10 +30,10 @@ class PowerExprMixin(ExpressionBaseMixin): @staticmethod def power( - state: ExpressionStateMixin, + state: ExpressionState, left: ExpressionBaseMixin, right: ExpressionBaseMixin | int | float - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + ) -> tuple[ExpressionState, PolyMatrixMixin]: """ Compute the expression ```left ** right```. """ exponent: int | None = None diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py index 12ea4f2..9db1b9c 100644 --- a/polymatrix/expression/mixins/productexprmixin.py +++ b/polymatrix/expression/mixins/productexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index 144fbab..cba2844 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index 2717e7e..3ce3b85 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -41,8 +41,8 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index e3c6498..7ef01ae 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -2,7 +2,7 @@ import abc import dataclassabc from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -19,8 +19,8 @@ class RepMatExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) # FIXME: move to polymatrix module diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 9f43c04..23f295b 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -6,7 +6,7 @@ import dataclassabc import numpy as np from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -25,10 +25,11 @@ class ReshapeExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) + # TODO: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class ReshapePolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 61db750..35c99d8 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class SetElementAtExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py index 67ade46..785647c 100644 --- a/polymatrix/expression/mixins/squeezeexprmixin.py +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # remove? diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index d2ab740..a3e76b1 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -7,7 +7,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index 6bac539..3efb656 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.polymatrix.utils.sortmonomials import sort_monomials @@ -29,8 +29,8 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = get_monomial_indices(state, self.underlying) state, sub_monomials = get_monomial_indices(state, self.monomials) diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py index 4ade032..af722a4 100644 --- a/polymatrix/expression/mixins/sumexprmixin.py +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -4,7 +4,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -22,8 +22,8 @@ class SumExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) poly_matrix_data = collections.defaultdict(dict) diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index ea9d16f..dd77bef 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -1,15 +1,13 @@ import abc import collections -import itertools import dataclassabc -import typing from polymatrix.polymatrix.index import PolyDict from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class SymmetricExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py index 9c8b9b7..51c2e56 100644 --- a/polymatrix/expression/mixins/toconstantexprmixin.py +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class ToConstantExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py index 4dfa518..865fdf5 100644 --- a/polymatrix/expression/mixins/tosortedvariablesmixin.py +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -6,7 +6,7 @@ from polymatrix.expression.utils.getvariableindices import ( from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # to be deleted? diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py index 3c50f76..65dfffd 100644 --- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.expression.utils.getvariableindices import ( ) from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -23,8 +23,8 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) assert underlying.shape[1] == 1 diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index eaab11a..16139ba 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.index import PolyDict from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class TransposeExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index 3171221..629dd45 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py index 750e320..fe8ad80 100644 --- a/polymatrix/expression/mixins/variablemixin.py +++ b/polymatrix/expression/mixins/variablemixin.py @@ -4,9 +4,8 @@ import typing import itertools from typing_extensions import override -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex from polymatrix.variable.abc import Variable @@ -16,7 +15,7 @@ class VariableMixin(ExpressionBaseMixin, Variable): """ Underlying object for VariableExpression """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state = state.register(self) indices = state.get_indices(self) p = PolyMatrixDict() diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index 8f17b2f..249bb26 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.polymatrix.index import PolyDict from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class VStackExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index c24b1d2..4509d1e 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -3,7 +3,7 @@ import sympy import numpy as np from polymatrix.expression.expression import Expression -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py index cf2ae92..de783d9 100644 --- a/polymatrix/expression/utils/getderivativemonomials.py +++ b/polymatrix/expression/utils/getderivativemonomials.py @@ -2,7 +2,7 @@ import collections import dataclasses import itertools -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolynomialData diff --git a/polymatrix/expression/utils/getmonomialindices.py b/polymatrix/expression/utils/getmonomialindices.py index 9145002..6470b5e 100644 --- a/polymatrix/expression/utils/getmonomialindices.py +++ b/polymatrix/expression/utils/getmonomialindices.py @@ -1,7 +1,8 @@ -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +# TODO: mark as deprecated, explain replacement def get_monomial_indices( state: ExpressionState, expression: ExpressionBaseMixin, diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index d87f82e..13b6aa0 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,7 +1,7 @@ import itertools import typing -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.utils.deprecation import deprecated diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate.py index 8d1145f..931f282 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING from abc import abstractmethod -from typing import NamedTuple, Iterable +from typing import Any, NamedTuple, Iterable from math import prod -import dataclasses +from dataclassabc import dataclassabc +from dataclasses import replace from polymatrix.variable.abc import Variable from polymatrix.utils.deprecation import deprecated @@ -21,20 +21,18 @@ class IndexRange(NamedTuple): return self.start < other.start -class ExpressionStateMixin( - StateCacheMixin, -): - @property - @abstractmethod - def n_variables(self) -> int: - """ Number of polynomial variables """ +@dataclassabc(frozen=True) +class ExpressionState(StateCacheMixin): + n_variables: int + """ Number of polynomial variables """ - @property - @abstractmethod - def indices(self) -> dict[Variable, IndexRange]: - """ Map from variable objects to their indices. """ + indices: dict[Variable, IndexRange] + """ Map from variable objects to their indices. """ - def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: + cache: dict + """ Cache for StateCacheMixin """ + + def index(self, var: Variable) -> tuple[ExpressionState, IndexRange]: """ Index a variable and get its index range. """ if not isinstance(var, Variable): raise ValueError("State can only index object of type Variable!") @@ -54,7 +52,7 @@ class ExpressionStateMixin( size = prod(var.shape) index = IndexRange(start=self.n_variables, end=self.n_variables + size) - return dataclasses.replace( + return replace( self, n_variables=self.n_variables + size, indices=self.indices | {var: index} @@ -174,3 +172,10 @@ class ExpressionStateMixin( def get_key_from_offset(self, index: int) -> Variable: return self.get_variable(index) + +def init_expression_state(n_variables: int = 0, indices: dict[Variable, IndexRange] = {}): + return ExpressionState( + n_variables=n_variables, + indices=indices, + cache={}, + ) diff --git a/polymatrix/expressionstate/__init__.py b/polymatrix/expressionstate/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/polymatrix/expressionstate/__init__.py +++ /dev/null diff --git a/polymatrix/expressionstate/abc.py b/polymatrix/expressionstate/abc.py deleted file mode 100644 index bdc4484..0000000 --- a/polymatrix/expressionstate/abc.py +++ /dev/null @@ -1,6 +0,0 @@ -from polymatrix.expressionstate.mixins import ExpressionStateMixin - - -# NP: "state" of an expression that maps indices to variable / parameter objects -class ExpressionState(ExpressionStateMixin): - pass diff --git a/polymatrix/expressionstate/impl.py b/polymatrix/expressionstate/impl.py deleted file mode 100644 index c243a26..0000000 --- a/polymatrix/expressionstate/impl.py +++ /dev/null @@ -1,10 +0,0 @@ -import dataclassabc - -from polymatrix.expressionstate.abc import ExpressionState - - -@dataclassabc.dataclassabc(frozen=True) -class ExpressionStateImpl(ExpressionState): - n_variables: int - indices: dict - cache: dict diff --git a/polymatrix/expressionstate/init.py b/polymatrix/expressionstate/init.py deleted file mode 100644 index d7c7d25..0000000 --- a/polymatrix/expressionstate/init.py +++ /dev/null @@ -1,9 +0,0 @@ -from polymatrix.expressionstate.impl import ExpressionStateImpl - - -def init_expression_state(n_param: int = 0, offset_dict: dict = {}): - return ExpressionStateImpl( - n_variables=n_param, - indices={}, - cache={}, - ) |