summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-04-15 19:02:01 +0200
committerNao Pross <np@0hm.ch>2024-04-15 19:07:17 +0200
commit13f11cb60021d4c143ce8c80e9b0c5027a4bf434 (patch)
treefc393153e5f26730abd2bbcd5155459fcbe13575
parentMake PolyMatrixDict accept indexing p[row, col] (diff)
downloadpolymatrix-13f11cb60021d4c143ce8c80e9b0c5027a4bf434.tar.gz
polymatrix-13f11cb60021d4c143ce8c80e9b0c5027a4bf434.zip
Update ExpressionState, breaking change
There is no (easy) way to make this change backwards compatible, the FromX classes will be adapted to the new format of the state.
-rw-r--r--polymatrix/expressionstate/impl.py5
-rw-r--r--polymatrix/expressionstate/init.py18
-rw-r--r--polymatrix/expressionstate/mixins.py113
3 files changed, 58 insertions, 78 deletions
diff --git a/polymatrix/expressionstate/impl.py b/polymatrix/expressionstate/impl.py
index c25dae5..c243a26 100644
--- a/polymatrix/expressionstate/impl.py
+++ b/polymatrix/expressionstate/impl.py
@@ -5,7 +5,6 @@ from polymatrix.expressionstate.abc import ExpressionState
@dataclassabc.dataclassabc(frozen=True)
class ExpressionStateImpl(ExpressionState):
- n_param: int
- offset_dict: dict
- auxillary_equations: dict[int, dict[tuple[int], float]]
+ n_variables: int
+ indices: dict
cache: dict
diff --git a/polymatrix/expressionstate/init.py b/polymatrix/expressionstate/init.py
index 8fdbabf..d7c7d25 100644
--- a/polymatrix/expressionstate/init.py
+++ b/polymatrix/expressionstate/init.py
@@ -1,21 +1,9 @@
from polymatrix.expressionstate.impl import ExpressionStateImpl
-def init_expression_state(
- n_param: int = None,
- offset_dict: dict = None,
-):
- # FIXME: just set the defaults above instead of None which btw is not
- # allowed by the type checker ("implicit none is not allowed")
- if n_param is None:
- n_param = 0
-
- if offset_dict is None:
- offset_dict = {}
-
+def init_expression_state(n_param: int = 0, offset_dict: dict = {}):
return ExpressionStateImpl(
- n_param=n_param,
- offset_dict=offset_dict,
- auxillary_equations={},
+ n_variables=n_param,
+ indices={},
cache={},
)
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 7122e80..ffaea0a 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -1,77 +1,70 @@
-import abc
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from abc import abstractmethod
+from typing import NamedTuple
+from math import prod
import dataclasses
-import typing
+
+if TYPE_CHECKING:
+ from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.statemonad.mixins import StateCacheMixin
+# TODO: move to typing submodule
+class IndexRange(NamedTuple):
+ start: int
+ end: int
+
-# NP: "state" of an expression that maps indices to variable / parameter objects
class ExpressionStateMixin(
StateCacheMixin,
):
- @property
- @abc.abstractmethod
- def n_param(self) -> int:
- """
- number of parameters used in polynomial matrix expressions
- """
- ...
+ # -- New API --
@property
- @abc.abstractmethod
- def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]:
- """
- a variable consists of one or more parameters indexed by a start
- and an end index
- """
- # NP: I call a thing (start, end) a _range_ or _interval_ of indices to index multiple varilables
- # NP: offset_dict is confusing IMHO, consider renaming
- ...
-
+ @abstractmethod
+ def n_variables(self) -> int:
+ """ Number of polynomial variables """
+
@property
- @abc.abstractmethod
- def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
- # NP: TODO explanation of how auxiliary equaitons work
- ...
+ @abstractmethod
+ def indices(self) -> dict[VariableMixin, IndexRange]:
+ """ Map from variable objects to their indices. """
- # NP: get a variable name from the offset, hovever you can only ever
- # NP: get names using offsets, so maybe rename to something more intutive like
- # FIXME: rename to get_variable_name() or get_parameter_name()
- def get_name_from_offset(self, offset: int):
- for variable, (start, end) in self.offset_dict.items():
- if start <= offset < end:
- return f"{str(variable)}_{offset-start}"
+ def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]:
+ """ Get the index of a variable. """
+ # Check if already in there
+ if var in self.indices:
+ return self, self.indices[var]
- # NP: key does not mean anything for someone who does not know how this class works inside
- # FIXME: rename to just get_variable() or get_parameter()
- def get_key_from_offset(self, offset: int):
- for variable, (start, end) in self.offset_dict.items():
- if start <= offset < end:
- return variable
+ # If not save new index
+ size = prod(var.shape)
+ index = IndexRange(start=self.n_variables, end=self.n_variables + size)
- # NP: register a variable / parameter into the state object
- # NP: why are you allowed to not give a key (use case)? also, rename key to variable / parameter
- # NP: register() is good, other good names are index(), index_variable(), index_parameter()
- def register(
- self,
- n_param: int,
- key: typing.Any = None, # NP: Any is close to useless, specify type
- ) -> "ExpressionStateMixin":
- if key is None:
- updated_state = dataclasses.replace(
- self,
- n_param=self.n_param + n_param,
- )
+ return dataclasses.replace( # type: ignore[type-var]
+ self,
+ n_variables=self.n_variables + size,
+ indices={**self.indices, var: index}
+ ), index
+
+ def register(self, var: VariableMixin) -> ExpressionStateMixin:
+ """
+ Create an index for a variable, but does not return the index. If you
+ want the index use
+ :py:meth:`polymatrix.expressionstate.mixins.ExpressionStateMixin.index`
+ """
+ state, _ = self.index(var)
+ return state
- elif key not in self.offset_dict:
- updated_state = dataclasses.replace(
- self,
- offset_dict=self.offset_dict
- | {key: (self.n_param, self.n_param + n_param)},
- n_param=self.n_param + n_param,
- )
+ def get_name(self, index: int) -> str:
+ """ Get the name of a variable given its index. """
+ for variable, (start, end) in self.indices.items():
+ if start <= index < end:
+ # Variable is not scalar
+ if end - start > 1:
+ return f"{variable.name}_{index - start}"
- else:
- updated_state = self
+ return variable.name
- return updated_state
+ raise IndexError(f"There is no variable with index {index}.")