diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-17 17:27:48 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-17 17:27:48 +0200 |
commit | 8341b1fae4454c1305a615b8069fcda42042a0cb (patch) | |
tree | c456859ead1ab4bf2d87e61b771570abcdc1d9f2 | |
parent | order methods in alphabetical order (diff) | |
download | polymatrix-8341b1fae4454c1305a615b8069fcda42042a0cb.tar.gz polymatrix-8341b1fae4454c1305a615b8069fcda42042a0cb.zip |
'parametrize' method accepts any size matrix, not only vectors
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/mixins/parametrizeexprmixin.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index b9673e4..b8054ef 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -31,14 +31,13 @@ class ParametrizeExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state) - assert underlying.shape[1] == 1 - start_index = state.n_param terms = {} - for row in range(underlying.shape[0]): - var_index = start_index + row - terms[row, 0] = {((var_index, 1),): 1.0} + for col in range(underlying.shape[1]): + for row in range(underlying.shape[0]): + var_index = start_index + col * underlying.shape[1] + row + terms[row, col] = {((var_index, 1),): 1.0} state = state.register( key=self, |