summaryrefslogtreecommitdiffstats
path: root/sumofsquares/init/initsosexpr.py
diff options
context:
space:
mode:
Diffstat (limited to 'sumofsquares/init/initsosexpr.py')
-rw-r--r--sumofsquares/init/initsosexpr.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/sumofsquares/init/initsosexpr.py b/sumofsquares/init/initsosexpr.py
new file mode 100644
index 0000000..c858549
--- /dev/null
+++ b/sumofsquares/init/initsosexpr.py
@@ -0,0 +1,83 @@
+import polymatrix
+from sumofsquares.abc.sosexpr import SOSExpr
+
+from sumofsquares.impl.sosexprimpl import ParamSOSExprImpl, SOSExprImpl
+from sumofsquares.mixins.sosexprmixin import SOSExprMixin
+from sumofsquares.sosexprbase.init.initsosexprbase import init_param_sos_expr_base, init_sos_expr_base
+from sumofsquares.sosexprbase.mixins.parametermixin import ParameterMixin
+
+
+def init_sos_expr(
+ expr: polymatrix.Expression,
+ variables: polymatrix.Expression,
+ dependence: tuple[ParameterMixin],
+):
+ return SOSExprImpl(
+ underlying=init_sos_expr_base(
+ expr=expr,
+ variables=variables,
+ dependence=dependence,
+ ),
+ )
+
+
+def init_param_expr(
+ name: str,
+ variables: polymatrix.Expression,
+ monom: polymatrix.Expression | None = None,
+ n_rows: int | None = None,
+):
+ return ParamSOSExprImpl(
+ underlying=init_param_sos_expr_base(
+ name=name,
+ monom=monom,
+ variables=variables,
+ n_rows=n_rows,
+ ),
+ )
+
+
+def init_param_expr_from_reference(
+ name: str,
+ reference: SOSExpr,
+ # variables: polymatrix.Expression,
+ multiplicand: SOSExpr | polymatrix.Expression | None = None,
+):
+ variables = reference.variables
+
+ if multiplicand is None:
+ multiplicand_expr = polymatrix.from_(1)
+
+ elif isinstance(multiplicand, polymatrix.Expression):
+ multiplicand_expr = multiplicand
+
+ elif isinstance(multiplicand, SOSExprMixin):
+ assert multiplicand.variables == variables, f'{multiplicand.variables=}, {variables=}'
+
+ multiplicand_expr = multiplicand.expr
+
+ else:
+ multiplicand_expr = polymatrix.from_(multiplicand)
+
+ m_sos_monom = multiplicand_expr.quadratic_monomials(variables)
+
+ max_degree = m_sos_monom.max_degree().T.max()
+
+ m_max_monom = m_sos_monom.filter(
+ m_sos_monom.max_degree() - max_degree,
+ inverse=True,
+ )
+
+ sos_monom = reference.expr.quadratic_monomials(variables).subtract_monomials(m_max_monom)
+
+ expr = (sos_monom @ sos_monom.T).reshape(1, -1).sum()
+
+ monom = expr.linear_monomials(variables).cache()
+
+ return init_param_expr(
+ name=name,
+ monom=monom,
+ variables=variables,
+ )
+
+