diff options
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 0e61591..b80669f 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -6,8 +6,8 @@ from typing import NamedTuple, cast from math import prod import dataclasses -if TYPE_CHECKING: - from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expression.expression import Expression +from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.statemonad.mixins import StateCacheMixin @@ -32,8 +32,16 @@ class ExpressionStateMixin( def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]: """ Get the index of a variable. """ + # Unwrap if wrapped in expression object + if isinstance(var, Expression): + var = var.underlying + + if not isinstance(var, VariableMixin): + raise ValueError("State can only index object of type VariableMixin or expressions " + "that contain a single VariableMixin object!") + # Check if already in there - if var in self.indices: + if var in self.indices.keys(): return self, self.indices[var] # If not save new index |