summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-11 17:59:28 +0200
committerNao Pross <np@0hm.ch>2024-05-11 17:59:28 +0200
commitf9ab663a2fdb05f670f8fe0efffe8bce0c2a08e6 (patch)
tree9ebf6eed8f1386f6f754513a68c3277a91ee861c
parentAdd more helper methods to ExpressionState (diff)
downloadpolymatrix-f9ab663a2fdb05f670f8fe0efffe8bce0c2a08e6.tar.gz
polymatrix-f9ab663a2fdb05f670f8fe0efffe8bce0c2a08e6.zip
Fix ParametrizeExpr, was broken since changes in ExpressionState
Commit that broke it: 13f11cb60021d4c143ce8c80e9b0c5027a4bf434
-rw-r--r--polymatrix/expression/mixins/parametrizeexprmixin.py81
1 files changed, 40 insertions, 41 deletions
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py
index 290b2eb..5762480 100644
--- a/polymatrix/expression/mixins/parametrizeexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizeexprmixin.py
@@ -1,13 +1,30 @@
import abc
import dataclasses
-from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
+from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex, VariableIndex
+from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.variable.init import init_variable
class ParametrizeExprMixin(ExpressionBaseMixin):
+ r"""
+ Given a column vector (or an expression that evaluates to a column vector)
+ :math:`x \in \mathbf{R}^n` and a name, for instance :math:`u`, create a new
+ vector of variables called :math:`u` of the same size as :math:`x`.
+
+ This is useful if you want to create coefficients, for example:
+
+ ::
+ x = polymatrix.v_stack((x_0, x_1, x_2)) # in R^3
+ u = x.parametrize('u')
+
+ # Create u_0 * x_0 + u_1 * x_1 + u_2 * x_2
+ u.T @ x
+ """
+
@property
@abc.abstractclassmethod
def underlying(self) -> ExpressionBaseMixin: ...
@@ -17,49 +34,31 @@ class ParametrizeExprMixin(ExpressionBaseMixin):
def name(self) -> str: ...
# overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionStateMixin,
- ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
state, underlying = self.underlying.apply(state)
- assert underlying.shape[1] == 1
+ nrows, ncols = underlying.shape
+ if ncols != 1:
+ raise ValueError("Parametrize works only with column vectors")
- if self.name in state.offset_dict:
+ # FIXME: not sure this behaviour is intuitive, discuss
+ if v := state.get_variable_from_name_or(self.name, if_not_present=False):
start, end = state.offset_dict[self.name]
-
- assert underlying.shape[0] == (end - start)
+ if nrows != (end - start):
+ raise ValueError("Cannot parametrize {self.underlying} with variable {v} "
+ "found in state object, because its shape {(nrow, ncols)} "
+ "does not match ({self.underlying.shape}). ")
else:
- start = state.n_param
-
- state = state.register(
- key=self.name,
- n_param=underlying.shape[0],
- )
-
- # # cache polymatrix to not re-parametrize at every apply call
- # if self in state.cache:
- # return state, state.cache[self]
-
- # state, underlying = self.underlying.apply(state)
-
- # assert underlying.shape[1] == 1
-
- poly_matrix_data = {}
-
- for row in range(underlying.shape[0]):
- var_index = start + row
- poly_matrix_data[row, 0] = {((var_index, 1),): 1}
-
- poly_matrix = init_poly_matrix(
- data=poly_matrix_data,
- shape=underlying.shape,
- )
-
- # state = dataclasses.replace(
- # state,
- # cache=state.cache | {self: poly_matrix},
- # )
-
- return state, poly_matrix
+ v = init_variable(self.name, shape=(nrows, 1))
+ state = state.register(v)
+
+ p = PolyMatrixDict({
+ MatrixIndex(row, 0): PolyDict({
+ MonomialIndex((VariableIndex(index, 1),)): 1
+ })
+ for row, index in enumerate(state.get_indices(v))
+ })
+
+ return state, init_poly_matrix(p, underlying.shape)