From f9ab663a2fdb05f670f8fe0efffe8bce0c2a08e6 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 11 May 2024 17:59:28 +0200 Subject: Fix ParametrizeExpr, was broken since changes in ExpressionState Commit that broke it: 13f11cb60021d4c143ce8c80e9b0c5027a4bf434 --- .../expression/mixins/parametrizeexprmixin.py | 81 +++++++++++----------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index 290b2eb..5762480 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -1,13 +1,30 @@ import abc import dataclasses -from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex, VariableIndex +from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.variable.init import init_variable class ParametrizeExprMixin(ExpressionBaseMixin): + r""" + Given a column vector (or an expression that evaluates to a column vector) + :math:`x \in \mathbf{R}^n` and a name, for instance :math:`u`, create a new + vector of variables called :math:`u` of the same size as :math:`x`. + + This is useful if you want to create coefficients, for example: + + :: + x = polymatrix.v_stack((x_0, x_1, x_2)) # in R^3 + u = x.parametrize('u') + + # Create u_0 * x_0 + u_1 * x_1 + u_2 * x_2 + u.T @ x + """ + @property @abc.abstractclassmethod def underlying(self) -> ExpressionBaseMixin: ... @@ -17,49 +34,31 @@ class ParametrizeExprMixin(ExpressionBaseMixin): def name(self) -> str: ... # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state, underlying = self.underlying.apply(state) - assert underlying.shape[1] == 1 + nrows, ncols = underlying.shape + if ncols != 1: + raise ValueError("Parametrize works only with column vectors") - if self.name in state.offset_dict: + # FIXME: not sure this behaviour is intuitive, discuss + if v := state.get_variable_from_name_or(self.name, if_not_present=False): start, end = state.offset_dict[self.name] - - assert underlying.shape[0] == (end - start) + if nrows != (end - start): + raise ValueError("Cannot parametrize {self.underlying} with variable {v} " + "found in state object, because its shape {(nrow, ncols)} " + "does not match ({self.underlying.shape}). ") else: - start = state.n_param - - state = state.register( - key=self.name, - n_param=underlying.shape[0], - ) - - # # cache polymatrix to not re-parametrize at every apply call - # if self in state.cache: - # return state, state.cache[self] - - # state, underlying = self.underlying.apply(state) - - # assert underlying.shape[1] == 1 - - poly_matrix_data = {} - - for row in range(underlying.shape[0]): - var_index = start + row - poly_matrix_data[row, 0] = {((var_index, 1),): 1} - - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=underlying.shape, - ) - - # state = dataclasses.replace( - # state, - # cache=state.cache | {self: poly_matrix}, - # ) - - return state, poly_matrix + v = init_variable(self.name, shape=(nrows, 1)) + state = state.register(v) + + p = PolyMatrixDict({ + MatrixIndex(row, 0): PolyDict({ + MonomialIndex((VariableIndex(index, 1),)): 1 + }) + for row, index in enumerate(state.get_indices(v)) + }) + + return state, init_poly_matrix(p, underlying.shape) -- cgit v1.2.1