From e5bbaea2d2b0192a8a45284b342a9b96dd3fc70f Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Tue, 28 May 2024 12:12:17 +0200
Subject: Fix incorrect concatenation of constraints SOSProblem.apply

There is still something off, but this is a start, need to debug more
---
 sumofsquares/problems.py | 64 +++++++++++++++++++++++++++++++-----------------
 1 file changed, 42 insertions(+), 22 deletions(-)

diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py
index c313419..bfb13c3 100644
--- a/sumofsquares/problems.py
+++ b/sumofsquares/problems.py
@@ -10,6 +10,7 @@ import numpy as np
 
 from dataclassabc import dataclassabc
 from dataclasses import replace
+from itertools import groupby
 from numpy.typing import NDArray
 from typing import Any, Sequence
 from typing_extensions import override
@@ -179,14 +180,14 @@ class SOSProblem(Problem):
         """
         Convert to internal SOS problem by applying state to the expressions.
 
-        **Technical Note:** The internal SOS problem may only constraints that
-        are linear in the optimization variables, hence, conversion of
+        **Technical Note:** The internal SOS problem may only have constraints
+        that are affine in the optimization variables, hence, conversion of
         polynomial equality / non-negativity constraints are done here.
         Likewise the cost function must also be reduced to quadratic expression
         here.
         """
 
-        constraints: list[Constraint[PolyMatrixMixin]] = []
+        constraints: list[Constraint] = []
         state, cost = self.cost.apply(state)
 
         # Compute the polymatrix of each constraint expression. Even though the
@@ -213,13 +214,12 @@ class SOSProblem(Problem):
                 # Polynomial equality must be converted into coefficient
                 # matching condition
                 if deg.scalar().constant() > 1:
-                    state, pm = c.expression.linear_in(x).apply(state)
-                    constraints.append(replace(c, expression=pm))
+                    cnew = c.expression.linear_in(x)
+                    constraints.append(replace(c, expression=cnew))
 
-                # A normal (linear) equality
+                # A normal (affine) equality
                 else:
-                    state, pm = c.expression.apply(state)
-                    constraints.append(replace(c, expression=pm))
+                    constraints.append(c)
 
             elif isinstance(c, NonNegative):
                 if c.domain:
@@ -234,38 +234,58 @@ class SOSProblem(Problem):
                 # constraint of SOS quadratic form
                 if deg.scalar().constant() > 1:
                     # TODO: it seems to work fine even without .symmetric(). Why?
-                    state, pm = c.expression.quadratic_in(x).symmetric().apply(state)
-                    constraints.append(PositiveSemiDefinite(pm))
+                    cnew = c.expression.quadratic_in(x).symmetric()
+                    constraints.append(PositiveSemiDefinite(cnew))
 
-                # A normal (linear) constraint
+                # A normal (affine) constraint
                 else:
-                    state, pm = c.expression.apply(state)
-                    constraints.append(replace(c, expression=pm))
+                    constraints.append(c)
 
             elif isinstance(c, PositiveSemiDefinite):
-                state, pm = c.expression.apply(state)
+                state, pm = c.expression.cache().apply(state)
                 nrows, ncols = pm.shape
                 if nrows != ncols:
                     raise ValueError(f"PSD constraint cannot contain non-square matrix of shape ({nrows, ncols})!")
 
                 # PSD constraint can be passed as-is
-                constraints.append(replace(c, expression=pm))
+                constraints.append(c)
 
             elif isinstance(c, ExponentialCone):
-                state, pm = c.expression.apply(state)
-                nrows, ncols = pm.shape
+                state, pm = c.expression.shape.apply(state)
 
-                if ncols != 3:
+                if pm.at(1, 0).constant() != 3:
                     raise ValueError("Conic constraint must be a row vector [x, y, z] ",
                                      "or for multiple constraints it must be an n x 3 "
                                     f"matrix! Given expression has wrong shape {pm.shape}.")
 
-                constraints.append(replace(c, expresssion=pm))
+                constraints.append(c)
 
             else:
                 raise NotImplementedError(f"Cannot process constraint of type {type(c)} (yet).")
 
-        return state, InternalSOSProblem(cost, tuple(constraints),
+        # Convert Expressions into PolyMatrix objects
+        # Concatenate constraints so that there is only a big constraint per cone.
+        pm_constraints: list[Constraint[PolyMatrixMixin]] = []
+
+        # TODO: can we get rid of for loop inside InternalSOSProblem.to_conic_problem?
+        for (ctype, group) in groupby(constraints, key=type):
+            if ctype in (EqualToZero, NonNegative):
+                state, pm = poly.v_stack((c.expression for c in group)).apply(state)
+                pm_constraints.append(ctype(pm))
+
+            elif ctype is PositiveSemiDefinite:
+                expressions = (c.expression for c in group)
+                state, pm = poly.block_diag(expressions).apply(state)
+                pm_constraints.append(ctype(pm))
+
+            elif ctype is ExponentialCone:
+                state, pm = poly.v_stack((c.expression for c in group)).apply(state)
+                pm_constraints.append(ctype(pm))
+
+            else:
+                raise NotImplementedError(f"Cannot process constraint of type {ctype} (yet).")
+
+        return state, InternalSOSProblem(cost, tuple(pm_constraints),
                                          tuple(variables), polynomial_variables,
                                          self.solver, state)
 
@@ -349,8 +369,8 @@ class InternalSOSProblem(Problem):
             nrows, ncols = constr.shape
 
             if constr.degree > 1:
-                # If this error occurs an it is not the user's fault, there is a bug in
-                # SOSProblem.apply
+                # If this error occurs an it is not the user's fault, there is
+                # a bug in SOSProblem.apply
                 raise ValueError("To convert to conic constraints must be linear or affine "
                                 f"but {str(c.expression)} has degree {constr.degree}.")
 
-- 
cgit v1.2.1