diff options
Diffstat (limited to 'sumofsquares/init/initsosexpr.py')
-rw-r--r-- | sumofsquares/init/initsosexpr.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/sumofsquares/init/initsosexpr.py b/sumofsquares/init/initsosexpr.py new file mode 100644 index 0000000..c858549 --- /dev/null +++ b/sumofsquares/init/initsosexpr.py @@ -0,0 +1,83 @@ +import polymatrix +from sumofsquares.abc.sosexpr import SOSExpr + +from sumofsquares.impl.sosexprimpl import ParamSOSExprImpl, SOSExprImpl +from sumofsquares.mixins.sosexprmixin import SOSExprMixin +from sumofsquares.sosexprbase.init.initsosexprbase import init_param_sos_expr_base, init_sos_expr_base +from sumofsquares.sosexprbase.mixins.parametermixin import ParameterMixin + + +def init_sos_expr( + expr: polymatrix.Expression, + variables: polymatrix.Expression, + dependence: tuple[ParameterMixin], +): + return SOSExprImpl( + underlying=init_sos_expr_base( + expr=expr, + variables=variables, + dependence=dependence, + ), + ) + + +def init_param_expr( + name: str, + variables: polymatrix.Expression, + monom: polymatrix.Expression | None = None, + n_rows: int | None = None, +): + return ParamSOSExprImpl( + underlying=init_param_sos_expr_base( + name=name, + monom=monom, + variables=variables, + n_rows=n_rows, + ), + ) + + +def init_param_expr_from_reference( + name: str, + reference: SOSExpr, + # variables: polymatrix.Expression, + multiplicand: SOSExpr | polymatrix.Expression | None = None, +): + variables = reference.variables + + if multiplicand is None: + multiplicand_expr = polymatrix.from_(1) + + elif isinstance(multiplicand, polymatrix.Expression): + multiplicand_expr = multiplicand + + elif isinstance(multiplicand, SOSExprMixin): + assert multiplicand.variables == variables, f'{multiplicand.variables=}, {variables=}' + + multiplicand_expr = multiplicand.expr + + else: + multiplicand_expr = polymatrix.from_(multiplicand) + + m_sos_monom = multiplicand_expr.quadratic_monomials(variables) + + max_degree = m_sos_monom.max_degree().T.max() + + m_max_monom = m_sos_monom.filter( + m_sos_monom.max_degree() - max_degree, + inverse=True, + ) + + sos_monom = reference.expr.quadratic_monomials(variables).subtract_monomials(m_max_monom) + + expr = (sos_monom @ sos_monom.T).reshape(1, -1).sum() + + monom = expr.linear_monomials(variables).cache() + + return init_param_expr( + name=name, + monom=monom, + variables=variables, + ) + + |