summaryrefslogtreecommitdiffstats
path: root/polymatrix/expressionstate/mixins.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expressionstate/mixins.py')
-rw-r--r--polymatrix/expressionstate/mixins.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
new file mode 100644
index 0000000..538f652
--- /dev/null
+++ b/polymatrix/expressionstate/mixins.py
@@ -0,0 +1,68 @@
+import abc
+import dataclasses
+import typing
+
+from polymatrix.statemonad.mixins.statemixin import StateCacheMixin
+
+
+class ExpressionStateMixin(
+ StateCacheMixin,
+):
+
+ @property
+ @abc.abstractmethod
+ def n_param(self) -> int:
+ """
+ number of parameters used in polynomial matrix expressions
+ """
+
+ ...
+
+ @property
+ @abc.abstractmethod
+ def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]:
+ """
+ a variable consists of one or more parameters indexed by a start
+ and an end index
+ """
+
+ ...
+
+ @property
+ @abc.abstractmethod
+ def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
+ ...
+
+ def get_name_from_offset(self, offset: int):
+ for variable, (start, end) in self.offset_dict.items():
+ if start <= offset < end:
+ return f'{str(variable)}_{offset-start}'
+
+ def get_key_from_offset(self, offset: int):
+ for variable, (start, end) in self.offset_dict.items():
+ if start <= offset < end:
+ return variable
+
+ def register(
+ self,
+ n_param: int,
+ key: typing.Any = None,
+ ) -> 'ExpressionStateMixin':
+
+ if key is None:
+ updated_state = dataclasses.replace(
+ self,
+ n_param=self.n_param + n_param,
+ )
+
+ elif key not in self.offset_dict:
+ updated_state = dataclasses.replace(
+ self,
+ offset_dict=self.offset_dict | {key: (self.n_param, self.n_param + n_param)},
+ n_param=self.n_param + n_param,
+ )
+
+ else:
+ updated_state = self
+
+ return updated_state