summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-04-15 19:02:01 +0200
committerNao Pross <np@0hm.ch>2024-04-15 19:07:17 +0200
commit13f11cb60021d4c143ce8c80e9b0c5027a4bf434 (patch)
treefc393153e5f26730abd2bbcd5155459fcbe13575
parentMake PolyMatrixDict accept indexing p[row, col] (diff)
downloadpolymatrix-13f11cb60021d4c143ce8c80e9b0c5027a4bf434.tar.gz
polymatrix-13f11cb60021d4c143ce8c80e9b0c5027a4bf434.zip
Update ExpressionState, breaking change
There is no (easy) way to make this change backwards compatible, the FromX classes will be adapted to the new format of the state.
Diffstat (limited to '')
-rw-r--r--polymatrix/expressionstate/impl.py5
-rw-r--r--polymatrix/expressionstate/init.py18
-rw-r--r--polymatrix/expressionstate/mixins.py113
3 files changed, 58 insertions, 78 deletions
diff --git a/polymatrix/expressionstate/impl.py b/polymatrix/expressionstate/impl.py
index c25dae5..c243a26 100644
--- a/polymatrix/expressionstate/impl.py
+++ b/polymatrix/expressionstate/impl.py
@@ -5,7 +5,6 @@ from polymatrix.expressionstate.abc import ExpressionState
@dataclassabc.dataclassabc(frozen=True)
class ExpressionStateImpl(ExpressionState):
- n_param: int
- offset_dict: dict
- auxillary_equations: dict[int, dict[tuple[int], float]]
+ n_variables: int
+ indices: dict
cache: dict
diff --git a/polymatrix/expressionstate/init.py b/polymatrix/expressionstate/init.py
index 8fdbabf..d7c7d25 100644
--- a/polymatrix/expressionstate/init.py
+++ b/polymatrix/expressionstate/init.py
@@ -1,21 +1,9 @@
from polymatrix.expressionstate.impl import ExpressionStateImpl
-def init_expression_state(
- n_param: int = None,
- offset_dict: dict = None,
-):
- # FIXME: just set the defaults above instead of None which btw is not
- # allowed by the type checker ("implicit none is not allowed")
- if n_param is None:
- n_param = 0
-
- if offset_dict is None:
- offset_dict = {}
-
+def init_expression_state(n_param: int = 0, offset_dict: dict = {}):
return ExpressionStateImpl(
- n_param=n_param,
- offset_dict=offset_dict,
- auxillary_equations={},
+ n_variables=n_param,
+ indices={},
cache={},
)
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 7122e80..ffaea0a 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -1,77 +1,70 @@
-import abc
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from abc import abstractmethod
+from typing import NamedTuple
+from math import prod
import dataclasses
-import typing
+
+if TYPE_CHECKING:
+ from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.statemonad.mixins import StateCacheMixin
+# TODO: move to typing submodule
+class IndexRange(NamedTuple):
+ start: int
+ end: int
+
-# NP: "state" of an expression that maps indices to variable / parameter objects
class ExpressionStateMixin(
StateCacheMixin,
):
- @property
- @abc.abstractmethod
- def n_param(self) -> int:
- """
- number of parameters used in polynomial matrix expressions
- """
- ...
+ # -- New API --
@property
- @abc.abstractmethod
- def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]:
- """
- a variable consists of one or more parameters indexed by a start
- and an end index
- """
- # NP: I call a thing (start, end) a _range_ or _interval_ of indices to index multiple varilables
- # NP: offset_dict is confusing IMHO, consider renaming
- ...
-
+ @abstractmethod
+ def n_variables(self) -> int:
+ """ Number of polynomial variables """
+
@property
- @abc.abstractmethod
- def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
- # NP: TODO explanation of how auxiliary equaitons work
- ...
+ @abstractmethod
+ def indices(self) -> dict[VariableMixin, IndexRange]:
+ """ Map from variable objects to their indices. """
- # NP: get a variable name from the offset, hovever you can only ever
- # NP: get names using offsets, so maybe rename to something more intutive like
- # FIXME: rename to get_variable_name() or get_parameter_name()
- def get_name_from_offset(self, offset: int):
- for variable, (start, end) in self.offset_dict.items():
- if start <= offset < end:
- return f"{str(variable)}_{offset-start}"
+ def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]:
+ """ Get the index of a variable. """
+ # Check if already in there
+ if var in self.indices:
+ return self, self.indices[var]
- # NP: key does not mean anything for someone who does not know how this class works inside
- # FIXME: rename to just get_variable() or get_parameter()
- def get_key_from_offset(self, offset: int):
- for variable, (start, end) in self.offset_dict.items():
- if start <= offset < end:
- return variable
+ # If not save new index
+ size = prod(var.shape)
+ index = IndexRange(start=self.n_variables, end=self.n_variables + size)
- # NP: register a variable / parameter into the state object
- # NP: why are you allowed to not give a key (use case)? also, rename key to variable / parameter
- # NP: register() is good, other good names are index(), index_variable(), index_parameter()
- def register(
- self,
- n_param: int,
- key: typing.Any = None, # NP: Any is close to useless, specify type
- ) -> "ExpressionStateMixin":
- if key is None:
- updated_state = dataclasses.replace(
- self,
- n_param=self.n_param + n_param,
- )
+ return dataclasses.replace( # type: ignore[type-var]
+ self,
+ n_variables=self.n_variables + size,
+ indices={**self.indices, var: index}
+ ), index
+
+ def register(self, var: VariableMixin) -> ExpressionStateMixin:
+ """
+ Create an index for a variable, but does not return the index. If you
+ want the index use
+ :py:meth:`polymatrix.expressionstate.mixins.ExpressionStateMixin.index`
+ """
+ state, _ = self.index(var)
+ return state
- elif key not in self.offset_dict:
- updated_state = dataclasses.replace(
- self,
- offset_dict=self.offset_dict
- | {key: (self.n_param, self.n_param + n_param)},
- n_param=self.n_param + n_param,
- )
+ def get_name(self, index: int) -> str:
+ """ Get the name of a variable given its index. """
+ for variable, (start, end) in self.indices.items():
+ if start <= index < end:
+ # Variable is not scalar
+ if end - start > 1:
+ return f"{variable.name}_{index - start}"
- else:
- updated_state = self
+ return variable.name
- return updated_state
+ raise IndexError(f"There is no variable with index {index}.")