summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/parametrizematrixexprmixin.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
new file mode 100644
index 0000000..f59ddb7
--- /dev/null
+++ b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
@@ -0,0 +1,62 @@
+
+import abc
+import dataclasses
+
+from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+# rename ParametrizeSymmetricMatrixExprMixin
+class ParametrizeMatrixExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractclassmethod
+ def name(self) -> str:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ # cache polymatrix to not re-parametrize at every apply call
+ if self in state.cache:
+ return state, state.cache[self]
+
+ state, underlying = self.underlying.apply(state)
+
+ assert underlying.shape[1] == 1
+
+ terms = {}
+ var_index = 0
+
+ for row in range(underlying.shape[0]):
+ for _ in range(row, underlying.shape[0]):
+
+ terms[var_index, 0] = {((state.n_param + var_index, 1),): 1.0}
+
+ var_index += 1
+
+ state = state.register(
+ key=self,
+ n_param=var_index,
+ )
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=(var_index, 1),
+ )
+
+ state = dataclasses.replace(
+ state,
+ cache=state.cache | {self: poly_matrix},
+ )
+
+ return state, poly_matrix