summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expressionstate/mixins.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 0e61591..b80669f 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -6,8 +6,8 @@ from typing import NamedTuple, cast
from math import prod
import dataclasses
-if TYPE_CHECKING:
- from polymatrix.expression.mixins.variablemixin import VariableMixin
+from polymatrix.expression.expression import Expression
+from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.statemonad.mixins import StateCacheMixin
@@ -32,8 +32,16 @@ class ExpressionStateMixin(
def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]:
""" Get the index of a variable. """
+ # Unwrap if wrapped in expression object
+ if isinstance(var, Expression):
+ var = var.underlying
+
+ if not isinstance(var, VariableMixin):
+ raise ValueError("State can only index object of type VariableMixin or expressions "
+ "that contain a single VariableMixin object!")
+
# Check if already in there
- if var in self.indices:
+ if var in self.indices.keys():
return self, self.indices[var]
# If not save new index