summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-17 17:27:48 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-17 17:27:48 +0200
commit8341b1fae4454c1305a615b8069fcda42042a0cb (patch)
treec456859ead1ab4bf2d87e61b771570abcdc1d9f2
parentorder methods in alphabetical order (diff)
downloadpolymatrix-8341b1fae4454c1305a615b8069fcda42042a0cb.tar.gz
polymatrix-8341b1fae4454c1305a615b8069fcda42042a0cb.zip
'parametrize' method accepts any size matrix, not only vectors
-rw-r--r--polymatrix/expression/mixins/parametrizeexprmixin.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py
index b9673e4..b8054ef 100644
--- a/polymatrix/expression/mixins/parametrizeexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizeexprmixin.py
@@ -31,14 +31,13 @@ class ParametrizeExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
- assert underlying.shape[1] == 1
-
start_index = state.n_param
terms = {}
- for row in range(underlying.shape[0]):
- var_index = start_index + row
- terms[row, 0] = {((var_index, 1),): 1.0}
+ for col in range(underlying.shape[1]):
+ for row in range(underlying.shape[0]):
+ var_index = start_index + col * underlying.shape[1] + row
+ terms[row, col] = {((var_index, 1),): 1.0}
state = state.register(
key=self,