diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-05-08 11:53:28 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-05-08 11:53:28 +0200 |
commit | 0e5e050166a221712cb4047625652bf0210e7c61 (patch) | |
tree | 6ac2fc6a2911795ea1477f5f39b26429b4d1a659 | |
parent | Initial commit (diff) | |
download | sumofsquares-0e5e050166a221712cb4047625652bf0210e7c61.tar.gz sumofsquares-0e5e050166a221712cb4047625652bf0210e7c61.zip |
allow sos expression to have more than 1 column
-rw-r--r-- | sumofsquares/init/initsosexpr.py | 6 | ||||
-rw-r--r-- | sumofsquares/mixins/sosexpropmixin.py | 11 | ||||
-rw-r--r-- | sumofsquares/sosexprbase/impl.py | 1 | ||||
-rw-r--r-- | sumofsquares/sosexprbase/init/initsosexprbase.py | 15 | ||||
-rw-r--r-- | sumofsquares/sosexprbase/mixins/exprfrommonommixin.py | 14 |
5 files changed, 39 insertions, 8 deletions
diff --git a/sumofsquares/init/initsosexpr.py b/sumofsquares/init/initsosexpr.py index c858549..214bb0d 100644 --- a/sumofsquares/init/initsosexpr.py +++ b/sumofsquares/init/initsosexpr.py @@ -25,14 +25,16 @@ def init_param_expr( name: str, variables: polymatrix.Expression, monom: polymatrix.Expression | None = None, - n_rows: int | None = None, + n_row: int | None = None, + n_col: int | None = None, ): return ParamSOSExprImpl( underlying=init_param_sos_expr_base( name=name, monom=monom, variables=variables, - n_rows=n_rows, + n_row=n_row, + n_col=n_col, ), ) diff --git a/sumofsquares/mixins/sosexpropmixin.py b/sumofsquares/mixins/sosexpropmixin.py index a345284..2020043 100644 --- a/sumofsquares/mixins/sosexpropmixin.py +++ b/sumofsquares/mixins/sosexpropmixin.py @@ -131,3 +131,14 @@ class SOSExprOPMixin(SOSExprMixin): dependence=self.dependence, ), ) + + def set_variables(self, variables): + return dataclasses.replace( + self, + underlying=init_sos_expr_base( + expr=self.expr, + variables=variables, + dependence=self.dependence, + ), + ) + diff --git a/sumofsquares/sosexprbase/impl.py b/sumofsquares/sosexprbase/impl.py index 02f220f..5a3a800 100644 --- a/sumofsquares/sosexprbase/impl.py +++ b/sumofsquares/sosexprbase/impl.py @@ -19,3 +19,4 @@ class ParamSOSExprBaseImpl(ParamSOSExprBase): monom: polymatrix.Expression variables: polymatrix.Expression param_matrix: polymatrix.Expression + n_row: int diff --git a/sumofsquares/sosexprbase/init/initsosexprbase.py b/sumofsquares/sosexprbase/init/initsosexprbase.py index 6b7f299..9dade7e 100644 --- a/sumofsquares/sosexprbase/init/initsosexprbase.py +++ b/sumofsquares/sosexprbase/init/initsosexprbase.py @@ -30,20 +30,24 @@ def init_param_sos_expr_base( name: str, variables: polymatrix.Expression, monom: polymatrix.Expression | None = None, - n_rows: int | None = None, + n_row: int | None = None, + n_col: int | None = None, ): if monom is None: monom = polymatrix.from_(1) - if n_rows is None: - n_rows = 1 + if n_row == None: + n_row = 1 - if n_rows is 1: + if n_col == None: + n_col = 1 + + if n_row == 1 and n_col == 1: param = monom.parametrize(f'{name}') param_matrix = param.T else: - params = tuple(monom.parametrize(f'{name}_{row+1}') for row in range(n_rows)) + params = tuple(monom.parametrize(f'{name}_{row+1}_{col+1}') for col in range(n_col) for row in range(n_row)) param = polymatrix.v_stack(params) param_matrix = polymatrix.v_stack(tuple(param.T for param in params)) @@ -53,4 +57,5 @@ def init_param_sos_expr_base( monom=monom, variables=variables, param_matrix=param_matrix, + n_row=n_row, ) diff --git a/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py b/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py index 354f46f..6e6e80b 100644 --- a/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py +++ b/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py @@ -12,5 +12,17 @@ class ExprFromMonomMixin(ParameterMixin, ExprBaseMixin): ... @property + @abc.abstractmethod + def n_row(self) -> int: + ... + + @property def expr(self) -> polymatrix.Expression: - return (self.param_matrix @ self.monom).cache() + expr_vec = (self.param_matrix @ self.monom).cache() + + if self.n_row == 1: + return expr_vec + + else: + return expr_vec.reshape(self.n_row, -1) + |