diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/mixins/parametrizematrixexprmixin.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py new file mode 100644 index 0000000..f59ddb7 --- /dev/null +++ b/polymatrix/expression/mixins/parametrizematrixexprmixin.py @@ -0,0 +1,62 @@ + +import abc +import dataclasses + +from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin +from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin + + +# rename ParametrizeSymmetricMatrixExprMixin +class ParametrizeMatrixExprMixin(ExpressionBaseMixin): + @property + @abc.abstractclassmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractclassmethod + def name(self) -> str: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionStateMixin, + ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + + # cache polymatrix to not re-parametrize at every apply call + if self in state.cache: + return state, state.cache[self] + + state, underlying = self.underlying.apply(state) + + assert underlying.shape[1] == 1 + + terms = {} + var_index = 0 + + for row in range(underlying.shape[0]): + for _ in range(row, underlying.shape[0]): + + terms[var_index, 0] = {((state.n_param + var_index, 1),): 1.0} + + var_index += 1 + + state = state.register( + key=self, + n_param=var_index, + ) + + poly_matrix = init_poly_matrix( + terms=terms, + shape=(var_index, 1), + ) + + state = dataclasses.replace( + state, + cache=state.cache | {self: poly_matrix}, + ) + + return state, poly_matrix |