From f402569ad74d4bd7bbcf254855d0f78807df7091 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 29 Apr 2024 15:59:44 +0200 Subject: Fix bug in ExpressionState.index If we do x = polymatrix.from_names("x") state = polymatrix.make_state() state, idx = state.index(x) There is problem because the state index should be a dictionary with keys of type VariableMixin, but x is of type Expression, so we need to take its underlying --- polymatrix/expressionstate/mixins.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 0e61591..b80669f 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -6,8 +6,8 @@ from typing import NamedTuple, cast from math import prod import dataclasses -if TYPE_CHECKING: - from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expression.expression import Expression +from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.statemonad.mixins import StateCacheMixin @@ -32,8 +32,16 @@ class ExpressionStateMixin( def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]: """ Get the index of a variable. """ + # Unwrap if wrapped in expression object + if isinstance(var, Expression): + var = var.underlying + + if not isinstance(var, VariableMixin): + raise ValueError("State can only index object of type VariableMixin or expressions " + "that contain a single VariableMixin object!") + # Check if already in there - if var in self.indices: + if var in self.indices.keys(): return self, self.indices[var] # If not save new index -- cgit v1.2.1