diff options
author | Nao Pross <np@0hm.ch> | 2024-05-04 16:39:44 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-04 16:39:44 +0200 |
commit | e5a5d02b1e8de2cfb370c3cdd73554fe8a6534e9 (patch) | |
tree | 59f8e7655db39c6771c152530ca6d8861abafaf0 | |
parent | Remove __hash__ dunder hack, the user has to manually unwrap (diff) | |
download | polymatrix-e5a5d02b1e8de2cfb370c3cdd73554fe8a6534e9.tar.gz polymatrix-e5a5d02b1e8de2cfb370c3cdd73554fe8a6534e9.zip |
Remove expression dependency in expressionstate
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 10 |
1 files changed, 2 insertions, 8 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 836032e..15f5712 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -6,7 +6,6 @@ from typing import NamedTuple, cast from math import prod import dataclasses -from polymatrix.expression.expression import Expression from polymatrix.variable.abc import Variable from polymatrix.statemonad.mixins import StateCacheMixin @@ -34,7 +33,7 @@ class ExpressionStateMixin( """ Map from variable objects to their indices. """ def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: - """ Get the index of a variable. """ + """ Get the index range of a variable. """ # Unwrap if wrapped in expression object if not isinstance(var, Variable): raise ValueError("State can only index object of type Variable!") @@ -47,13 +46,8 @@ class ExpressionStateMixin( size = prod(var.shape) index = IndexRange(start=self.n_variables, end=self.n_variables + size) - # FIXME: this is the only way the typechecker is happy, and it does - # make sense because ExpressionState is not a dataclass. - # The typechecker should be somehow informed that this is an ABC that - # will eventually become a dataclass, but of course it can't know that - from .impl import ExpressionStateImpl return dataclasses.replace( - cast(ExpressionStateImpl, self), + self, n_variables=self.n_variables + size, indices={**self.indices, var: index} ), index |