summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-28 12:09:47 +0200
committerNao Pross <np@0hm.ch>2024-05-28 12:09:47 +0200
commit7124cdd84745d9bbf56f9567b3b5f2b116f8ef53 (patch)
tree5e49dfdfc04196b1a3d0a1d6c986862913728400
parentImplement ExponentialCone in SOSProblem.apply, clean up comments (diff)
downloadsumofsquares-7124cdd84745d9bbf56f9567b3b5f2b116f8ef53.tar.gz
sumofsquares-7124cdd84745d9bbf56f9567b3b5f2b116f8ef53.zip
Update (empty) MOSEK and SCS interfaces
-rw-r--r--sumofsquares/solver/mosek.py30
-rw-r--r--sumofsquares/solver/scs.py134
2 files changed, 157 insertions, 7 deletions
diff --git a/sumofsquares/solver/mosek.py b/sumofsquares/solver/mosek.py
index 763555b..de4f92f 100644
--- a/sumofsquares/solver/mosek.py
+++ b/sumofsquares/solver/mosek.py
@@ -50,9 +50,37 @@ def solve_cone(prob: Problem, verbose: bool = False,
if not MOSEK_ENV:
raise RuntimeError("You forgot to call `sumofsquares.solvers.mosek.setup(license)`!")
+ if prob.constraints["z"]:
+ raise NotImplementedError
+
+ if prob.constraints["l"]:
+ raise NotImplementedError
+
+ if prob.constraints["b"]:
+ raise NotImplementedError
+
+ if prob.constraints["q"]:
+ raise NotImplementedError
+
+ if prob.constraints["s"]:
+ raise NotImplementedError
+
+ if prob.constraints["ep"]:
+ raise NotImplementedError
+
+ if prob.constraints["ep*"]:
+ raise NotImplementedError
+
+ if prob.constraints["p"]:
+ raise NotImplementedError
+
+ if prob.constraints["p*"]:
+ raise NotImplementedError
+
+
with MOSEK_ENV as env:
with env.Task() as task:
if verbose:
task.set_Stream(mosek.streamtype.log, _streamprinter)
- raise NotImplementedError
+ return {}, MOSEKInfo()
diff --git a/sumofsquares/solver/scs.py b/sumofsquares/solver/scs.py
index 4f17a24..e7ab6b8 100644
--- a/sumofsquares/solver/scs.py
+++ b/sumofsquares/solver/scs.py
@@ -4,12 +4,16 @@ Solve sumofsquares problems using SCS
from __future__ import annotations
import scs
+import math
+import numpy as np
from collections import UserDict
from numpy.typing import NDArray
+from scipy.sparse import csc_matrix
from typing import TYPE_CHECKING
-from ..abc import Problem, SolverInfo
+from ..abc import SolverInfo
+from ..error import SolverError
from ..variable import OptVariable
if TYPE_CHECKING:
@@ -26,19 +30,137 @@ class SCSInfo(SolverInfo, UserDict):
return self[key]
-def vectorize_matrix(m: NDArray) -> NDArray:
+def vec(m: NDArray) -> NDArray:
r"""
Vectorize a symmetric matrix for SCS by storing only a half of the entries
and multiplying the off-diaognal elements by :math:`\sqrt{2}`.
"""
- raise NotImplementedError
+ n, _ = m.shape
+ # Since it is symmetric we only work with the lower triangular part
+ scaling = (np.tril(n, k=-1) * np.sqrt(2) + np.eye(n))
+ v = (m * scaling)[np.tril_indices(n)].reshape((-1, 1))
+ assert v.shape[0] == n * (n + 1) // 2, "SCS Matrix vectorization is incorrect!"
+ return v
+
+
+def mat(v: NDArray) -> NDArray:
+ r"""
+ Reconstruct a symmetric matrix from a vector that was created using
+ :py:fn:`vec`. This scales the off-diagonal entries by :math:`1/\sqrt{2}`.
+ """
+ n_float = .5 * (-1 + np.sqrt(1 + 8 * v.shape[0]))
+ n = int(n_float)
+ if not math.isclose(n, n_float):
+ raise ValueError("To construct a n x n symmetric matrix, the vector "
+ "must be an n * (n + 1) / 2 dimensional column.")
+
+ m = np.zeros((n, n))
+ m[np.tril_indices(n)] = v
+
+ scaling = (np.tril(n, k=-1) / np.sqrt(2) + np.eye(n))
+ m = m * scaling
+ m = m + m.T - np.diag(np.diag(m))
+ return m
def solve_cone(prob: ConicProblem, verbose: bool = False,
*args, **kwargs) -> tuple[dict[OptVariable, float], SCSInfo]:
r"""
- Solve a conic problem in the cone of SOS polynomials
- :math:`\mathbf{\Sigma}_d(x)` using SCS.
+ Any `*args` and `**kwargs` other than `prob` and `verbose` are passed
+ directly to the SCS solver call.
"""
+ # SCS solves problems that have the following (primal) form
+ #
+ # minimize .5 * x.T @ P @ x + q.T @ x
+ #
+ # subject to Ax + s = b
+ # s in K
+ #
+ # where K is a product of cones (in the following order):
+ #
+ # z zero cone
+ # l linear cone
+ # b box cone
+ # q second order cone
+ # s positive semidefinite cone
+ # ep exponential cone
+ # ep* dual exponential cone
+ # p power cone
+ # p* dual power cone
+
+ P, q = None, None
+ A_rows, b_rows = [], []
+
+ # for (linear, constant) in prob.constrains["z"]:
+
+ if prob.constraints["l"]:
+ raise NotImplementedError
+
+ if prob.constraints["b"]:
+ raise NotImplementedError
+
+ if prob.constraints["q"]:
+ raise NotImplementedError
+
+ if prob.constraints["s"]:
+ raise NotImplementedError
+
+ if prob.constraints["ep"]:
+ raise NotImplementedError
+
+ if prob.constraints["ep*"]:
+ raise NotImplementedError
+
+ if prob.constraints["p"]:
+ raise NotImplementedError
+
+ if prob.constraints["p*"]:
+ raise NotImplementedError
+
+ assert len(A_rows) > 0, "Problem is unconstrained! Something is very wrong"
+
+ A = np.hstack(A_rows)
+ b = np.hstack(b_rows)
+
+ data = {
+ "P": csc_matrix(P) if P is not None else None,
+ "q": q,
+ "A": csc_matrix(A), "b": b,
+ }
+
+ cone = {
+ "z": sum(prob.dims["z"]),
+ "l": sum(prob.dims["l"]),
+ "b": prob.dims["b"], # TODO: check
+ "q": prob.dims["q"],
+ "s": prob.dims["s"],
+ "ep": sum(prob.dims["e"]), # TODO: set directly to remove sum
+ "ed": sum(prob.dims["e*"]), # TODO: set directly to remove sum
+ "p": prob.dims["p"] + list(-psize for psize in prob.dims["p*"]),
+ }
+
+ try:
+ solver = scs.SCS(data, cone, *args, verbose=verbose, **kwargs)
+ # TODO: add mechanism to pass stuff for warm start, also store
+ # ScsSolver object somewhere for reuse?
+ sol = solver.solve()
+
+ except Exception as e:
+ raise SolverError("SCS can't solve this problem, "
+ "see previous exception for details on why.") from e
+
+ if sol['x'] is None:
+ return {}, SCSInfo(sol["info"])
+
+ results, i = {}, 0
+ for variable in prob.variables:
+ num_indices = math.prod(variable.shape)
+ values = np.array(sol["x"][i:i+num_indices]).reshape(variable.shape)
+ if values.shape == (1, 1):
+ values = values[0, 0]
+
+ results[variable] = values
+ i += num_indices
+
+ return results, SCSInfo(sol["info"])
- raise NotImplementedError