summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-04-29 15:59:44 +0200
committerNao Pross <np@0hm.ch>2024-04-29 15:59:44 +0200
commitf402569ad74d4bd7bbcf254855d0f78807df7091 (patch)
treea3f2cdc2b9537269f132b6d33d4ee5b2c93f3f41
parentFix circular imports caused by type annotations (diff)
downloadpolymatrix-f402569ad74d4bd7bbcf254855d0f78807df7091.tar.gz
polymatrix-f402569ad74d4bd7bbcf254855d0f78807df7091.zip
Fix bug in ExpressionState.index
If we do x = polymatrix.from_names("x") state = polymatrix.make_state() state, idx = state.index(x) There is problem because the state index should be a dictionary with keys of type VariableMixin, but x is of type Expression, so we need to take its underlying
-rw-r--r--polymatrix/expressionstate/mixins.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 0e61591..b80669f 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -6,8 +6,8 @@ from typing import NamedTuple, cast
from math import prod
import dataclasses
-if TYPE_CHECKING:
- from polymatrix.expression.mixins.variablemixin import VariableMixin
+from polymatrix.expression.expression import Expression
+from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.statemonad.mixins import StateCacheMixin
@@ -32,8 +32,16 @@ class ExpressionStateMixin(
def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]:
""" Get the index of a variable. """
+ # Unwrap if wrapped in expression object
+ if isinstance(var, Expression):
+ var = var.underlying
+
+ if not isinstance(var, VariableMixin):
+ raise ValueError("State can only index object of type VariableMixin or expressions "
+ "that contain a single VariableMixin object!")
+
# Check if already in there
- if var in self.indices:
+ if var in self.indices.keys():
return self, self.indices[var]
# If not save new index