diff options
Diffstat (limited to 'sumofsquares/cvxopt.py')
-rw-r--r-- | sumofsquares/cvxopt.py | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/sumofsquares/cvxopt.py b/sumofsquares/cvxopt.py index 93346c3..ec31892 100644 --- a/sumofsquares/cvxopt.py +++ b/sumofsquares/cvxopt.py @@ -4,8 +4,7 @@ import polymatrix import numpy as np import math -from sumofsquares.abc import ParamSOSExpr -from sumofsquares.sosconstraint.abc import SOSConstraint +from sumofsquares.sosexpr.abc import ParamSOSExpr, SOSExpr @dataclasses.dataclass @@ -348,18 +347,17 @@ def solve_cone2( def solve_sos_problem( cost: tuple[polymatrix.Expression], - sos_constraints: tuple[SOSConstraint], + sos_constraints: tuple[SOSExpr], state: polymatrix.ExpressionState, - free_param: tuple[ParamSOSExpr] = None, - x0: dict[ParamSOSExpr, np.ndarray] = None, + free_param: tuple[ParamSOSExpr] | None = None, + x0: dict[ParamSOSExpr, np.ndarray] | None = None, ): if x0 is None: x0 = {} def gen_all_param_expr(): for sos_constraint in sos_constraints: - for param_expr in sos_constraint.dependence: - yield param_expr + yield from sos_constraint.dependence all_param_expr = tuple(set(gen_all_param_expr())) @@ -371,7 +369,7 @@ def solve_sos_problem( def gen_inequality(): for sos_constraint in sos_constraints: if any(param_expr in free_param for param_expr in sos_constraint.dependence): - yield sos_constraint.constraint.eval(sub_vals) + yield sos_constraint.sos_matrix_vec.eval(sub_vals) inequality = tuple(gen_inequality()) @@ -387,10 +385,10 @@ def solve_sos_problem( def solve_sos_problem2( cost: tuple[polymatrix.Expression], - sos_constraints: tuple[SOSConstraint], + sos_constraints: tuple[SOSExpr], state: polymatrix.ExpressionState, - subs: tuple[ParamSOSExpr] = None, - x0: dict[ParamSOSExpr, np.ndarray] = None, + subs: tuple[ParamSOSExpr] | None = None, + x0: dict[ParamSOSExpr, np.ndarray] | None = None, print_info = False, ): if x0 is None: @@ -414,7 +412,7 @@ def solve_sos_problem2( def gen_inequality(): for sos_constraint in sos_constraints: if any(param_expr in free_param_expr for param_expr in sos_constraint.dependence): - yield sos_constraint.constraint.eval(sub_vals) + yield sos_constraint.sos_matrix_vec.eval(sub_vals) inequality = tuple(gen_inequality()) |