summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-05 11:29:23 +0200
committerNao Pross <np@0hm.ch>2024-05-05 11:29:23 +0200
commitb189292b458335aecdda807368d5d3e22feceb47 (patch)
tree5467e91976eb59dcf80148521fb429a8c410108e
parentFix typo in variable.init (diff)
downloadpolymatrix-b189292b458335aecdda807368d5d3e22feceb47.tar.gz
polymatrix-b189292b458335aecdda807368d5d3e22feceb47.zip
Make expressionstate partially backwards compatible
See also 13f11cb60021d4c143ce8c80e9b0c5027a4bf434
-rw-r--r--polymatrix/expressionstate/mixins.py47
1 files changed, 45 insertions, 2 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 15f5712..e4da232 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -2,11 +2,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from abc import abstractmethod
-from typing import NamedTuple, cast
+from typing import NamedTuple, Iterable
from math import prod
import dataclasses
from polymatrix.variable.abc import Variable
+from polymatrix.utils.deprecation import deprecated
from polymatrix.statemonad.mixins import StateCacheMixin
@@ -33,7 +34,7 @@ class ExpressionStateMixin(
""" Map from variable objects to their indices. """
def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]:
- """ Get the index range of a variable. """
+ """ Index a variable and get its index range. """
# Unwrap if wrapped in expression object
if not isinstance(var, Variable):
raise ValueError("State can only index object of type Variable!")
@@ -61,6 +62,26 @@ class ExpressionStateMixin(
state, _ = self.index(var)
return state
+ def get_indices(self, var: Variable) -> Iterable[int]:
+ """
+ Get all indices associated to a variable.
+
+ When a variable is not a scalar multiple indices will be associated to
+ the variables, one for each entry.
+ """
+ if var not in self.indices:
+ raise IndexError(f"There is no variable with index {index}.")
+
+ yield from range(*self.indices[var])
+
+ def get_variable(self, index: int) -> Variable:
+ """ Get the variable object from its index. """
+ for variable, (start, end) in self.indices.items():
+ if start <= index < end:
+ return variable
+
+ raise IndexError(f"There is no variable with index {index}.")
+
def get_name(self, index: int) -> str:
""" Get the name of a variable given its index. """
for variable, (start, end) in self.indices.items():
@@ -72,3 +93,25 @@ class ExpressionStateMixin(
return variable.name
raise IndexError(f"There is no variable with index {index}.")
+
+ # -- Old API ---
+
+ @property
+ @deprecated("replaced by n_variables")
+ def n_param(self) -> int:
+ return self.n_variables
+
+ @property
+ @deprecated("replaced by indices")
+ def offset_dict(self):
+ return self.indices
+
+ @property
+ @deprecated("Support for auxillary equations was removed")
+ def auxillary_equations(self):
+ return {}
+
+ @deprecated("replaced by get_variable")
+ def get_key_from_offset(self, index: int) -> Variable:
+ return self.get_variable(index)
+