summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-04 16:39:44 +0200
committerNao Pross <np@0hm.ch>2024-05-04 16:39:44 +0200
commite5a5d02b1e8de2cfb370c3cdd73554fe8a6534e9 (patch)
tree59f8e7655db39c6771c152530ca6d8861abafaf0
parentRemove __hash__ dunder hack, the user has to manually unwrap (diff)
downloadpolymatrix-e5a5d02b1e8de2cfb370c3cdd73554fe8a6534e9.tar.gz
polymatrix-e5a5d02b1e8de2cfb370c3cdd73554fe8a6534e9.zip
Remove expression dependency in expressionstate
Diffstat (limited to '')
-rw-r--r--polymatrix/expressionstate/mixins.py10
1 files changed, 2 insertions, 8 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 836032e..15f5712 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -6,7 +6,6 @@ from typing import NamedTuple, cast
from math import prod
import dataclasses
-from polymatrix.expression.expression import Expression
from polymatrix.variable.abc import Variable
from polymatrix.statemonad.mixins import StateCacheMixin
@@ -34,7 +33,7 @@ class ExpressionStateMixin(
""" Map from variable objects to their indices. """
def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]:
- """ Get the index of a variable. """
+ """ Get the index range of a variable. """
# Unwrap if wrapped in expression object
if not isinstance(var, Variable):
raise ValueError("State can only index object of type Variable!")
@@ -47,13 +46,8 @@ class ExpressionStateMixin(
size = prod(var.shape)
index = IndexRange(start=self.n_variables, end=self.n_variables + size)
- # FIXME: this is the only way the typechecker is happy, and it does
- # make sense because ExpressionState is not a dataclass.
- # The typechecker should be somehow informed that this is an ABC that
- # will eventually become a dataclass, but of course it can't know that
- from .impl import ExpressionStateImpl
return dataclasses.replace(
- cast(ExpressionStateImpl, self),
+ self,
n_variables=self.n_variables + size,
indices={**self.indices, var: index}
), index