diff options
author | Nao Pross <np@0hm.ch> | 2024-05-05 11:29:23 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-05 11:29:23 +0200 |
commit | b189292b458335aecdda807368d5d3e22feceb47 (patch) | |
tree | 5467e91976eb59dcf80148521fb429a8c410108e | |
parent | Fix typo in variable.init (diff) | |
download | polymatrix-b189292b458335aecdda807368d5d3e22feceb47.tar.gz polymatrix-b189292b458335aecdda807368d5d3e22feceb47.zip |
Make expressionstate partially backwards compatible
See also 13f11cb60021d4c143ce8c80e9b0c5027a4bf434
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 47 |
1 files changed, 45 insertions, 2 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 15f5712..e4da232 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -2,11 +2,12 @@ from __future__ import annotations from typing import TYPE_CHECKING from abc import abstractmethod -from typing import NamedTuple, cast +from typing import NamedTuple, Iterable from math import prod import dataclasses from polymatrix.variable.abc import Variable +from polymatrix.utils.deprecation import deprecated from polymatrix.statemonad.mixins import StateCacheMixin @@ -33,7 +34,7 @@ class ExpressionStateMixin( """ Map from variable objects to their indices. """ def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: - """ Get the index range of a variable. """ + """ Index a variable and get its index range. """ # Unwrap if wrapped in expression object if not isinstance(var, Variable): raise ValueError("State can only index object of type Variable!") @@ -61,6 +62,26 @@ class ExpressionStateMixin( state, _ = self.index(var) return state + def get_indices(self, var: Variable) -> Iterable[int]: + """ + Get all indices associated to a variable. + + When a variable is not a scalar multiple indices will be associated to + the variables, one for each entry. + """ + if var not in self.indices: + raise IndexError(f"There is no variable with index {index}.") + + yield from range(*self.indices[var]) + + def get_variable(self, index: int) -> Variable: + """ Get the variable object from its index. """ + for variable, (start, end) in self.indices.items(): + if start <= index < end: + return variable + + raise IndexError(f"There is no variable with index {index}.") + def get_name(self, index: int) -> str: """ Get the name of a variable given its index. """ for variable, (start, end) in self.indices.items(): @@ -72,3 +93,25 @@ class ExpressionStateMixin( return variable.name raise IndexError(f"There is no variable with index {index}.") + + # -- Old API --- + + @property + @deprecated("replaced by n_variables") + def n_param(self) -> int: + return self.n_variables + + @property + @deprecated("replaced by indices") + def offset_dict(self): + return self.indices + + @property + @deprecated("Support for auxillary equations was removed") + def auxillary_equations(self): + return {} + + @deprecated("replaced by get_variable") + def get_key_from_offset(self, index: int) -> Variable: + return self.get_variable(index) + |