diff options
author | Nao Pross <np@0hm.ch> | 2024-05-11 19:32:34 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-12 15:18:35 +0200 |
commit | f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8 (patch) | |
tree | 6d43f6a776fa6d15e69a68c330497aedf3422d93 | |
parent | Delete ParametrizeMatrixExprMixin, update ParametrizeExprMixin to work with a... (diff) | |
download | polymatrix-f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8.tar.gz polymatrix-f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8.zip |
Collapse ExpressionState to a single class
ExpressionState has the usual mixin, whereby one splits the ABC into
smaller parts called mixins, however this class only had one mixin
and it makes no sense to have three classes instead of one. This class
was simply not complex enough to make it worth maintaining all these
files.
I do not see a use case that would extend ExpressionState in a way that
would introduce enough complexity for the mixin structure to be necessary.
But it that were to happen in the future it is easy to reintroduce this
structure I just deleted.
In addition to reducing complexity this also fixes the inconsistent use
of ExpressionStateMixin and ExpressionState across the mixin classes
of Expression.
62 files changed, 122 insertions, 148 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 94da22a..62fc4a6 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -1,5 +1,5 @@ -from polymatrix.expressionstate.abc import ExpressionState as internal_ExpressionState -from polymatrix.expressionstate.init import ( +from polymatrix.expressionstate import ( + ExpressionState as internal_ExpressionState, init_expression_state as internal_init_expression_state, ) diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index 1eda6fe..52040a1 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -3,7 +3,7 @@ import numpy as np from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.getvariableindices import ( diff --git a/polymatrix/denserepr/impl.py b/polymatrix/denserepr/impl.py index 9f53c8b..4d0835e 100644 --- a/polymatrix/denserepr/impl.py +++ b/polymatrix/denserepr/impl.py @@ -4,7 +4,7 @@ import typing import numpy as np import scipy.sparse -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 415f16f..2e7cf11 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -8,17 +8,15 @@ from abc import ABC, abstractmethod from dataclassabc import dataclassabc from typing_extensions import override - -from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.variable.abc import Variable import polymatrix.expression.init -from polymatrix.utils.getstacklines import get_stack_lines - -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.utils.getstacklines import get_stack_lines +from polymatrix.variable.abc import Variable + from polymatrix.expression.op import ( diff, integrate, diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index a51bae5..4b0a9c9 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index dece9c3..45e5e1b 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -5,7 +5,7 @@ import dataclassabc from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index 7947a74..c8a0a23 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -4,7 +4,7 @@ import dataclasses from polymatrix.polymatrix.mixins import PolyMatrixAsDictMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -18,8 +18,8 @@ class CacheExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: # FIXME: return type + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: # FIXME: return type if self in state.cache: return state, state.cache[self] diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index 121d71c..abbb53d 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -5,7 +5,7 @@ from typing import Iterable from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/degreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py index 522af7c..331c350 100644 --- a/polymatrix/expression/mixins/degreeexprmixin.py +++ b/polymatrix/expression/mixins/degreeexprmixin.py @@ -4,7 +4,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import MatrixIndex, PolyMatrixDict, PolyDict, MonomialIndex from polymatrix.utils.getstacklines import FrameSummary diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 31a1e31..dab1b41 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -4,7 +4,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py index d181543..b7666f0 100644 --- a/polymatrix/expression/mixins/diagexprmixin.py +++ b/polymatrix/expression/mixins/diagexprmixin.py @@ -2,7 +2,7 @@ import abc import dataclassabc from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -24,8 +24,8 @@ class DiagExprMixin(ExpressionBaseMixin): # FIXME: typing def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) # Vector to diagonal matrix diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 3477ee1..074caa8 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -5,7 +5,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index e914fd5..6ef97a3 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.index import MonomialIndex from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 881ee83..103c556 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -5,7 +5,7 @@ import math from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py index 7f58339..dc8f699 100644 --- a/polymatrix/expression/mixins/expressionbasemixin.py +++ b/polymatrix/expression/mixins/expressionbasemixin.py @@ -1,9 +1,9 @@ import abc -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin class ExpressionBaseMixin(abc.ABC): @abc.abstractmethod - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: ... + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: ... diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index 6a65585..74dd930 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyDict, MonomialIndex diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index c8fa427..3cb6e5f 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class FilterExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index 046f191..6edeb78 100644 --- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # is this class needed? diff --git a/polymatrix/expression/mixins/fromnumbersexprmixin.py b/polymatrix/expression/mixins/fromnumbersexprmixin.py index e34956b..d1c3ccf 100644 --- a/polymatrix/expression/mixins/fromnumbersexprmixin.py +++ b/polymatrix/expression/mixins/fromnumbersexprmixin.py @@ -2,7 +2,7 @@ from abc import abstractmethod from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix @@ -24,7 +24,7 @@ class FromNumbersExprMixin(ExpressionBaseMixin): """ The matrix of numbers in row major order. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: p = PolyMatrixDict.empty() for r, row in enumerate(self.data): diff --git a/polymatrix/expression/mixins/fromnumpyexprmixin.py b/polymatrix/expression/mixins/fromnumpyexprmixin.py index cdf0aa6..cee5cc0 100644 --- a/polymatrix/expression/mixins/fromnumpyexprmixin.py +++ b/polymatrix/expression/mixins/fromnumpyexprmixin.py @@ -5,7 +5,7 @@ from typing_extensions import override, cast from numpy.typing import NDArray from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_poly_matrix @@ -26,7 +26,7 @@ class FromNumpyExprMixin(ExpressionBaseMixin): """ The Numpy array. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: p = PolyMatrixDict.empty() if len(self.data.shape) > 2: diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py index 254d125..b7d9ea7 100644 --- a/polymatrix/expression/mixins/fromstatemonad.py +++ b/polymatrix/expression/mixins/fromstatemonad.py @@ -3,7 +3,7 @@ from typing_extensions import override from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.statemonad.abc import StateMonad @@ -23,7 +23,7 @@ class FromStateMonadMixin(ExpressionBaseMixin): """ The state monad object. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, expr = self.monad.apply(state) # Case when monad wraps function # f: ExpressionState -> (State, Expression) diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py index c7d90a1..55a64e5 100644 --- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.expression.utils.getvariableindices import get_variable_indices_ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -21,8 +21,8 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) @@ -43,4 +43,4 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): shape=(var_index, 1), ) - return state, poly_matrix
\ No newline at end of file + return state, poly_matrix diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index d178415..6d5446a 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -5,7 +5,7 @@ from abc import abstractmethod from typing_extensions import override, cast from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix @@ -36,7 +36,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): """ The sympy objects. """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: # Unpack if it is a sympy matrix if isinstance(self.data, sympy.Matrix): diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 463f41c..8d340fd 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -2,7 +2,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import MonomialData @@ -26,8 +26,8 @@ class FromPolynomialDataExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: data = {coord: dict(polynomial) for coord, polynomial in self.data} poly_matrix = init_poly_matrix( diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index affe4ec..6fac468 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class GetItemExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py index 52e5814..01b6c47 100644 --- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py +++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py @@ -9,7 +9,7 @@ from polymatrix.expression.utils.getvariableindices import ( from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -29,8 +29,8 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, sos_monomials = get_monomial_indices(state, self.monomials) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/integrateexprmixin.py b/polymatrix/expression/mixins/integrateexprmixin.py index 38cf522..5b1db7b 100644 --- a/polymatrix/expression/mixins/integrateexprmixin.py +++ b/polymatrix/expression/mixins/integrateexprmixin.py @@ -4,7 +4,7 @@ import itertools from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -41,10 +41,7 @@ class IntegrateExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) state, variables = get_variable_indices_from_variable(state, self.variables) diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py index bec4c6c..d2c8372 100644 --- a/polymatrix/expression/mixins/legendreseriesmixin.py +++ b/polymatrix/expression/mixins/legendreseriesmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 8354754..d9e8032 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index 8239947..7ddc546 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -5,7 +5,7 @@ from numpy import var from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index c4a5df7..076cbb1 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -40,8 +40,8 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index 2d86619..10f5737 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -4,7 +4,7 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.tooperatorexception import to_operator_exception diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index 650dc68..60c1832 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -2,7 +2,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -15,8 +15,8 @@ class MaxExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) poly_matrix_data = {} diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index ca49411..5ffecf3 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -4,7 +4,7 @@ from abc import abstractmethod from itertools import product from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex, VariableIndex from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -36,7 +36,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin): def name(self) -> str: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) nrows, ncols = underlying.shape diff --git a/polymatrix/expression/mixins/powerexprmixin.py b/polymatrix/expression/mixins/powerexprmixin.py index 4ca1b0b..6d6aae4 100644 --- a/polymatrix/expression/mixins/powerexprmixin.py +++ b/polymatrix/expression/mixins/powerexprmixin.py @@ -7,7 +7,7 @@ from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -30,10 +30,10 @@ class PowerExprMixin(ExpressionBaseMixin): @staticmethod def power( - state: ExpressionStateMixin, + state: ExpressionState, left: ExpressionBaseMixin, right: ExpressionBaseMixin | int | float - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + ) -> tuple[ExpressionState, PolyMatrixMixin]: """ Compute the expression ```left ** right```. """ exponent: int | None = None diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py index 12ea4f2..9db1b9c 100644 --- a/polymatrix/expression/mixins/productexprmixin.py +++ b/polymatrix/expression/mixins/productexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.getstacklines import FrameSummary from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index 144fbab..cba2844 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index 2717e7e..3ce3b85 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, @@ -41,8 +41,8 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) state, variable_indices = get_variable_indices_from_variable( state, self.variables diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index e3c6498..7ef01ae 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -2,7 +2,7 @@ import abc import dataclassabc from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -19,8 +19,8 @@ class RepMatExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) # FIXME: move to polymatrix module diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 9f43c04..23f295b 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -6,7 +6,7 @@ import dataclassabc import numpy as np from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyDict @@ -25,10 +25,11 @@ class ReshapeExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) + # TODO: move to polymatrix module @dataclassabc.dataclassabc(frozen=True) class ReshapePolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 61db750..35c99d8 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class SetElementAtExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py index 67ade46..785647c 100644 --- a/polymatrix/expression/mixins/squeezeexprmixin.py +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # remove? diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index d2ab740..a3e76b1 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -7,7 +7,7 @@ import typing from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index 6bac539..3efb656 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -3,7 +3,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.polymatrix.utils.sortmonomials import sort_monomials @@ -29,8 +29,8 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = get_monomial_indices(state, self.underlying) state, sub_monomials = get_monomial_indices(state, self.monomials) diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py index 4ade032..af722a4 100644 --- a/polymatrix/expression/mixins/sumexprmixin.py +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -4,7 +4,7 @@ import dataclasses from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -22,8 +22,8 @@ class SumExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) poly_matrix_data = collections.defaultdict(dict) diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index ea9d16f..dd77bef 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -1,15 +1,13 @@ import abc import collections -import itertools import dataclassabc -import typing from polymatrix.polymatrix.index import PolyDict from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class SymmetricExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py index 9c8b9b7..51c2e56 100644 --- a/polymatrix/expression/mixins/toconstantexprmixin.py +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -4,7 +4,7 @@ import collections from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class ToConstantExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py index 4dfa518..865fdf5 100644 --- a/polymatrix/expression/mixins/tosortedvariablesmixin.py +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -6,7 +6,7 @@ from polymatrix.expression.utils.getvariableindices import ( from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState # to be deleted? diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py index 3c50f76..65dfffd 100644 --- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.expression.utils.getvariableindices import ( ) from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -23,8 +23,8 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin): # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) assert underlying.shape[1] == 1 diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index eaab11a..16139ba 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.index import PolyDict from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class TransposeExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index 3171221..629dd45 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -3,7 +3,7 @@ import abc from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import ( get_variable_indices_from_variable, ) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py index 750e320..fe8ad80 100644 --- a/polymatrix/expression/mixins/variablemixin.py +++ b/polymatrix/expression/mixins/variablemixin.py @@ -4,9 +4,8 @@ import typing import itertools from typing_extensions import override -from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex from polymatrix.variable.abc import Variable @@ -16,7 +15,7 @@ class VariableMixin(ExpressionBaseMixin, Variable): """ Underlying object for VariableExpression """ @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state = state.register(self) indices = state.get_indices(self) p = PolyMatrixDict() diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index 8f17b2f..249bb26 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.polymatrix.index import PolyDict from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState class VStackExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index c24b1d2..4509d1e 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -3,7 +3,7 @@ import sympy import numpy as np from polymatrix.expression.expression import Expression -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py index cf2ae92..de783d9 100644 --- a/polymatrix/expression/utils/getderivativemonomials.py +++ b/polymatrix/expression/utils/getderivativemonomials.py @@ -2,7 +2,7 @@ import collections import dataclasses import itertools -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolynomialData diff --git a/polymatrix/expression/utils/getmonomialindices.py b/polymatrix/expression/utils/getmonomialindices.py index 9145002..6470b5e 100644 --- a/polymatrix/expression/utils/getmonomialindices.py +++ b/polymatrix/expression/utils/getmonomialindices.py @@ -1,7 +1,8 @@ -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +# TODO: mark as deprecated, explain replacement def get_monomial_indices( state: ExpressionState, expression: ExpressionBaseMixin, diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index d87f82e..13b6aa0 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,7 +1,7 @@ import itertools import typing -from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.utils.deprecation import deprecated diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate.py index 8d1145f..931f282 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING from abc import abstractmethod -from typing import NamedTuple, Iterable +from typing import Any, NamedTuple, Iterable from math import prod -import dataclasses +from dataclassabc import dataclassabc +from dataclasses import replace from polymatrix.variable.abc import Variable from polymatrix.utils.deprecation import deprecated @@ -21,20 +21,18 @@ class IndexRange(NamedTuple): return self.start < other.start -class ExpressionStateMixin( - StateCacheMixin, -): - @property - @abstractmethod - def n_variables(self) -> int: - """ Number of polynomial variables """ +@dataclassabc(frozen=True) +class ExpressionState(StateCacheMixin): + n_variables: int + """ Number of polynomial variables """ - @property - @abstractmethod - def indices(self) -> dict[Variable, IndexRange]: - """ Map from variable objects to their indices. """ + indices: dict[Variable, IndexRange] + """ Map from variable objects to their indices. """ - def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: + cache: dict + """ Cache for StateCacheMixin """ + + def index(self, var: Variable) -> tuple[ExpressionState, IndexRange]: """ Index a variable and get its index range. """ if not isinstance(var, Variable): raise ValueError("State can only index object of type Variable!") @@ -54,7 +52,7 @@ class ExpressionStateMixin( size = prod(var.shape) index = IndexRange(start=self.n_variables, end=self.n_variables + size) - return dataclasses.replace( + return replace( self, n_variables=self.n_variables + size, indices=self.indices | {var: index} @@ -174,3 +172,10 @@ class ExpressionStateMixin( def get_key_from_offset(self, index: int) -> Variable: return self.get_variable(index) + +def init_expression_state(n_variables: int = 0, indices: dict[Variable, IndexRange] = {}): + return ExpressionState( + n_variables=n_variables, + indices=indices, + cache={}, + ) diff --git a/polymatrix/expressionstate/__init__.py b/polymatrix/expressionstate/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/polymatrix/expressionstate/__init__.py +++ /dev/null diff --git a/polymatrix/expressionstate/abc.py b/polymatrix/expressionstate/abc.py deleted file mode 100644 index bdc4484..0000000 --- a/polymatrix/expressionstate/abc.py +++ /dev/null @@ -1,6 +0,0 @@ -from polymatrix.expressionstate.mixins import ExpressionStateMixin - - -# NP: "state" of an expression that maps indices to variable / parameter objects -class ExpressionState(ExpressionStateMixin): - pass diff --git a/polymatrix/expressionstate/impl.py b/polymatrix/expressionstate/impl.py deleted file mode 100644 index c243a26..0000000 --- a/polymatrix/expressionstate/impl.py +++ /dev/null @@ -1,10 +0,0 @@ -import dataclassabc - -from polymatrix.expressionstate.abc import ExpressionState - - -@dataclassabc.dataclassabc(frozen=True) -class ExpressionStateImpl(ExpressionState): - n_variables: int - indices: dict - cache: dict diff --git a/polymatrix/expressionstate/init.py b/polymatrix/expressionstate/init.py deleted file mode 100644 index d7c7d25..0000000 --- a/polymatrix/expressionstate/init.py +++ /dev/null @@ -1,9 +0,0 @@ -from polymatrix.expressionstate.impl import ExpressionStateImpl - - -def init_expression_state(n_param: int = 0, offset_dict: dict = {}): - return ExpressionStateImpl( - n_variables=n_param, - indices={}, - cache={}, - ) |