diff options
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 10 |
1 files changed, 2 insertions, 8 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 836032e..15f5712 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -6,7 +6,6 @@ from typing import NamedTuple, cast from math import prod import dataclasses -from polymatrix.expression.expression import Expression from polymatrix.variable.abc import Variable from polymatrix.statemonad.mixins import StateCacheMixin @@ -34,7 +33,7 @@ class ExpressionStateMixin( """ Map from variable objects to their indices. """ def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: - """ Get the index of a variable. """ + """ Get the index range of a variable. """ # Unwrap if wrapped in expression object if not isinstance(var, Variable): raise ValueError("State can only index object of type Variable!") @@ -47,13 +46,8 @@ class ExpressionStateMixin( size = prod(var.shape) index = IndexRange(start=self.n_variables, end=self.n_variables + size) - # FIXME: this is the only way the typechecker is happy, and it does - # make sense because ExpressionState is not a dataclass. - # The typechecker should be somehow informed that this is an ABC that - # will eventually become a dataclass, but of course it can't know that - from .impl import ExpressionStateImpl return dataclasses.replace( - cast(ExpressionStateImpl, self), + self, n_variables=self.n_variables + size, indices={**self.indices, var: index} ), index |