summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-05-08 11:53:28 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-05-08 11:53:28 +0200
commit0e5e050166a221712cb4047625652bf0210e7c61 (patch)
tree6ac2fc6a2911795ea1477f5f39b26429b4d1a659
parentInitial commit (diff)
downloadsumofsquares-0e5e050166a221712cb4047625652bf0210e7c61.tar.gz
sumofsquares-0e5e050166a221712cb4047625652bf0210e7c61.zip
allow sos expression to have more than 1 column
-rw-r--r--sumofsquares/init/initsosexpr.py6
-rw-r--r--sumofsquares/mixins/sosexpropmixin.py11
-rw-r--r--sumofsquares/sosexprbase/impl.py1
-rw-r--r--sumofsquares/sosexprbase/init/initsosexprbase.py15
-rw-r--r--sumofsquares/sosexprbase/mixins/exprfrommonommixin.py14
5 files changed, 39 insertions, 8 deletions
diff --git a/sumofsquares/init/initsosexpr.py b/sumofsquares/init/initsosexpr.py
index c858549..214bb0d 100644
--- a/sumofsquares/init/initsosexpr.py
+++ b/sumofsquares/init/initsosexpr.py
@@ -25,14 +25,16 @@ def init_param_expr(
name: str,
variables: polymatrix.Expression,
monom: polymatrix.Expression | None = None,
- n_rows: int | None = None,
+ n_row: int | None = None,
+ n_col: int | None = None,
):
return ParamSOSExprImpl(
underlying=init_param_sos_expr_base(
name=name,
monom=monom,
variables=variables,
- n_rows=n_rows,
+ n_row=n_row,
+ n_col=n_col,
),
)
diff --git a/sumofsquares/mixins/sosexpropmixin.py b/sumofsquares/mixins/sosexpropmixin.py
index a345284..2020043 100644
--- a/sumofsquares/mixins/sosexpropmixin.py
+++ b/sumofsquares/mixins/sosexpropmixin.py
@@ -131,3 +131,14 @@ class SOSExprOPMixin(SOSExprMixin):
dependence=self.dependence,
),
)
+
+ def set_variables(self, variables):
+ return dataclasses.replace(
+ self,
+ underlying=init_sos_expr_base(
+ expr=self.expr,
+ variables=variables,
+ dependence=self.dependence,
+ ),
+ )
+
diff --git a/sumofsquares/sosexprbase/impl.py b/sumofsquares/sosexprbase/impl.py
index 02f220f..5a3a800 100644
--- a/sumofsquares/sosexprbase/impl.py
+++ b/sumofsquares/sosexprbase/impl.py
@@ -19,3 +19,4 @@ class ParamSOSExprBaseImpl(ParamSOSExprBase):
monom: polymatrix.Expression
variables: polymatrix.Expression
param_matrix: polymatrix.Expression
+ n_row: int
diff --git a/sumofsquares/sosexprbase/init/initsosexprbase.py b/sumofsquares/sosexprbase/init/initsosexprbase.py
index 6b7f299..9dade7e 100644
--- a/sumofsquares/sosexprbase/init/initsosexprbase.py
+++ b/sumofsquares/sosexprbase/init/initsosexprbase.py
@@ -30,20 +30,24 @@ def init_param_sos_expr_base(
name: str,
variables: polymatrix.Expression,
monom: polymatrix.Expression | None = None,
- n_rows: int | None = None,
+ n_row: int | None = None,
+ n_col: int | None = None,
):
if monom is None:
monom = polymatrix.from_(1)
- if n_rows is None:
- n_rows = 1
+ if n_row == None:
+ n_row = 1
- if n_rows is 1:
+ if n_col == None:
+ n_col = 1
+
+ if n_row == 1 and n_col == 1:
param = monom.parametrize(f'{name}')
param_matrix = param.T
else:
- params = tuple(monom.parametrize(f'{name}_{row+1}') for row in range(n_rows))
+ params = tuple(monom.parametrize(f'{name}_{row+1}_{col+1}') for col in range(n_col) for row in range(n_row))
param = polymatrix.v_stack(params)
param_matrix = polymatrix.v_stack(tuple(param.T for param in params))
@@ -53,4 +57,5 @@ def init_param_sos_expr_base(
monom=monom,
variables=variables,
param_matrix=param_matrix,
+ n_row=n_row,
)
diff --git a/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py b/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py
index 354f46f..6e6e80b 100644
--- a/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py
+++ b/sumofsquares/sosexprbase/mixins/exprfrommonommixin.py
@@ -12,5 +12,17 @@ class ExprFromMonomMixin(ParameterMixin, ExprBaseMixin):
...
@property
+ @abc.abstractmethod
+ def n_row(self) -> int:
+ ...
+
+ @property
def expr(self) -> polymatrix.Expression:
- return (self.param_matrix @ self.monom).cache()
+ expr_vec = (self.param_matrix @ self.monom).cache()
+
+ if self.n_row == 1:
+ return expr_vec
+
+ else:
+ return expr_vec.reshape(self.n_row, -1)
+