diff options
-rw-r--r-- | polymatrix/expressionstate/impl.py | 5 | ||||
-rw-r--r-- | polymatrix/expressionstate/init.py | 18 | ||||
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 113 |
3 files changed, 58 insertions, 78 deletions
diff --git a/polymatrix/expressionstate/impl.py b/polymatrix/expressionstate/impl.py index c25dae5..c243a26 100644 --- a/polymatrix/expressionstate/impl.py +++ b/polymatrix/expressionstate/impl.py @@ -5,7 +5,6 @@ from polymatrix.expressionstate.abc import ExpressionState @dataclassabc.dataclassabc(frozen=True) class ExpressionStateImpl(ExpressionState): - n_param: int - offset_dict: dict - auxillary_equations: dict[int, dict[tuple[int], float]] + n_variables: int + indices: dict cache: dict diff --git a/polymatrix/expressionstate/init.py b/polymatrix/expressionstate/init.py index 8fdbabf..d7c7d25 100644 --- a/polymatrix/expressionstate/init.py +++ b/polymatrix/expressionstate/init.py @@ -1,21 +1,9 @@ from polymatrix.expressionstate.impl import ExpressionStateImpl -def init_expression_state( - n_param: int = None, - offset_dict: dict = None, -): - # FIXME: just set the defaults above instead of None which btw is not - # allowed by the type checker ("implicit none is not allowed") - if n_param is None: - n_param = 0 - - if offset_dict is None: - offset_dict = {} - +def init_expression_state(n_param: int = 0, offset_dict: dict = {}): return ExpressionStateImpl( - n_param=n_param, - offset_dict=offset_dict, - auxillary_equations={}, + n_variables=n_param, + indices={}, cache={}, ) diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 7122e80..ffaea0a 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -1,77 +1,70 @@ -import abc +from __future__ import annotations +from typing import TYPE_CHECKING + +from abc import abstractmethod +from typing import NamedTuple +from math import prod import dataclasses -import typing + +if TYPE_CHECKING: + from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.statemonad.mixins import StateCacheMixin +# TODO: move to typing submodule +class IndexRange(NamedTuple): + start: int + end: int + -# NP: "state" of an expression that maps indices to variable / parameter objects class ExpressionStateMixin( StateCacheMixin, ): - @property - @abc.abstractmethod - def n_param(self) -> int: - """ - number of parameters used in polynomial matrix expressions - """ - ... + # -- New API -- @property - @abc.abstractmethod - def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]: - """ - a variable consists of one or more parameters indexed by a start - and an end index - """ - # NP: I call a thing (start, end) a _range_ or _interval_ of indices to index multiple varilables - # NP: offset_dict is confusing IMHO, consider renaming - ... - + @abstractmethod + def n_variables(self) -> int: + """ Number of polynomial variables """ + @property - @abc.abstractmethod - def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]: - # NP: TODO explanation of how auxiliary equaitons work - ... + @abstractmethod + def indices(self) -> dict[VariableMixin, IndexRange]: + """ Map from variable objects to their indices. """ - # NP: get a variable name from the offset, hovever you can only ever - # NP: get names using offsets, so maybe rename to something more intutive like - # FIXME: rename to get_variable_name() or get_parameter_name() - def get_name_from_offset(self, offset: int): - for variable, (start, end) in self.offset_dict.items(): - if start <= offset < end: - return f"{str(variable)}_{offset-start}" + def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]: + """ Get the index of a variable. """ + # Check if already in there + if var in self.indices: + return self, self.indices[var] - # NP: key does not mean anything for someone who does not know how this class works inside - # FIXME: rename to just get_variable() or get_parameter() - def get_key_from_offset(self, offset: int): - for variable, (start, end) in self.offset_dict.items(): - if start <= offset < end: - return variable + # If not save new index + size = prod(var.shape) + index = IndexRange(start=self.n_variables, end=self.n_variables + size) - # NP: register a variable / parameter into the state object - # NP: why are you allowed to not give a key (use case)? also, rename key to variable / parameter - # NP: register() is good, other good names are index(), index_variable(), index_parameter() - def register( - self, - n_param: int, - key: typing.Any = None, # NP: Any is close to useless, specify type - ) -> "ExpressionStateMixin": - if key is None: - updated_state = dataclasses.replace( - self, - n_param=self.n_param + n_param, - ) + return dataclasses.replace( # type: ignore[type-var] + self, + n_variables=self.n_variables + size, + indices={**self.indices, var: index} + ), index + + def register(self, var: VariableMixin) -> ExpressionStateMixin: + """ + Create an index for a variable, but does not return the index. If you + want the index use + :py:meth:`polymatrix.expressionstate.mixins.ExpressionStateMixin.index` + """ + state, _ = self.index(var) + return state - elif key not in self.offset_dict: - updated_state = dataclasses.replace( - self, - offset_dict=self.offset_dict - | {key: (self.n_param, self.n_param + n_param)}, - n_param=self.n_param + n_param, - ) + def get_name(self, index: int) -> str: + """ Get the name of a variable given its index. """ + for variable, (start, end) in self.indices.items(): + if start <= index < end: + # Variable is not scalar + if end - start > 1: + return f"{variable.name}_{index - start}" - else: - updated_state = self + return variable.name - return updated_state + raise IndexError(f"There is no variable with index {index}.") |