summaryrefslogtreecommitdiffstats
path: root/sumofsquares/cvxopt.py
diff options
context:
space:
mode:
Diffstat (limited to 'sumofsquares/cvxopt.py')
-rw-r--r--sumofsquares/cvxopt.py22
1 files changed, 10 insertions, 12 deletions
diff --git a/sumofsquares/cvxopt.py b/sumofsquares/cvxopt.py
index 93346c3..ec31892 100644
--- a/sumofsquares/cvxopt.py
+++ b/sumofsquares/cvxopt.py
@@ -4,8 +4,7 @@ import polymatrix
import numpy as np
import math
-from sumofsquares.abc import ParamSOSExpr
-from sumofsquares.sosconstraint.abc import SOSConstraint
+from sumofsquares.sosexpr.abc import ParamSOSExpr, SOSExpr
@dataclasses.dataclass
@@ -348,18 +347,17 @@ def solve_cone2(
def solve_sos_problem(
cost: tuple[polymatrix.Expression],
- sos_constraints: tuple[SOSConstraint],
+ sos_constraints: tuple[SOSExpr],
state: polymatrix.ExpressionState,
- free_param: tuple[ParamSOSExpr] = None,
- x0: dict[ParamSOSExpr, np.ndarray] = None,
+ free_param: tuple[ParamSOSExpr] | None = None,
+ x0: dict[ParamSOSExpr, np.ndarray] | None = None,
):
if x0 is None:
x0 = {}
def gen_all_param_expr():
for sos_constraint in sos_constraints:
- for param_expr in sos_constraint.dependence:
- yield param_expr
+ yield from sos_constraint.dependence
all_param_expr = tuple(set(gen_all_param_expr()))
@@ -371,7 +369,7 @@ def solve_sos_problem(
def gen_inequality():
for sos_constraint in sos_constraints:
if any(param_expr in free_param for param_expr in sos_constraint.dependence):
- yield sos_constraint.constraint.eval(sub_vals)
+ yield sos_constraint.sos_matrix_vec.eval(sub_vals)
inequality = tuple(gen_inequality())
@@ -387,10 +385,10 @@ def solve_sos_problem(
def solve_sos_problem2(
cost: tuple[polymatrix.Expression],
- sos_constraints: tuple[SOSConstraint],
+ sos_constraints: tuple[SOSExpr],
state: polymatrix.ExpressionState,
- subs: tuple[ParamSOSExpr] = None,
- x0: dict[ParamSOSExpr, np.ndarray] = None,
+ subs: tuple[ParamSOSExpr] | None = None,
+ x0: dict[ParamSOSExpr, np.ndarray] | None = None,
print_info = False,
):
if x0 is None:
@@ -414,7 +412,7 @@ def solve_sos_problem2(
def gen_inequality():
for sos_constraint in sos_constraints:
if any(param_expr in free_param_expr for param_expr in sos_constraint.dependence):
- yield sos_constraint.constraint.eval(sub_vals)
+ yield sos_constraint.sos_matrix_vec.eval(sub_vals)
inequality = tuple(gen_inequality())