diff options
author | Nao Pross <np@0hm.ch> | 2024-04-29 15:59:44 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-04-29 15:59:44 +0200 |
commit | f402569ad74d4bd7bbcf254855d0f78807df7091 (patch) | |
tree | a3f2cdc2b9537269f132b6d33d4ee5b2c93f3f41 | |
parent | Fix circular imports caused by type annotations (diff) | |
download | polymatrix-f402569ad74d4bd7bbcf254855d0f78807df7091.tar.gz polymatrix-f402569ad74d4bd7bbcf254855d0f78807df7091.zip |
Fix bug in ExpressionState.index
If we do
x = polymatrix.from_names("x")
state = polymatrix.make_state()
state, idx = state.index(x)
There is problem because the state index should be a dictionary with
keys of type VariableMixin, but x is of type Expression, so we need
to take its underlying
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 0e61591..b80669f 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -6,8 +6,8 @@ from typing import NamedTuple, cast from math import prod import dataclasses -if TYPE_CHECKING: - from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expression.expression import Expression +from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.statemonad.mixins import StateCacheMixin @@ -32,8 +32,16 @@ class ExpressionStateMixin( def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]: """ Get the index of a variable. """ + # Unwrap if wrapped in expression object + if isinstance(var, Expression): + var = var.underlying + + if not isinstance(var, VariableMixin): + raise ValueError("State can only index object of type VariableMixin or expressions " + "that contain a single VariableMixin object!") + # Check if already in there - if var in self.indices: + if var in self.indices.keys(): return self, self.indices[var] # If not save new index |