From a25c5d60e0874a8387aae31e5aa5ffd02bcdbac0 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Tue, 4 Jun 2024 17:53:44 +0200
Subject: Update variables to use symbols, see polymatrix dependency

commit 5c5268d2adfa3dfb6fb1426ac3d59d08c9e36d2b
---
 sumofsquares/abc.py           |  4 +--
 sumofsquares/problems.py      | 63 ++++++++++++++++++++++---------------------
 sumofsquares/solver/cvxopt.py | 10 +++----
 sumofsquares/solver/mosek.py  |  4 +--
 sumofsquares/solver/scs.py    | 10 +++----
 sumofsquares/variable.py      | 40 ++++++++++++++-------------
 6 files changed, 69 insertions(+), 62 deletions(-)

diff --git a/sumofsquares/abc.py b/sumofsquares/abc.py
index 20f4bd6..2f852df 100644
--- a/sumofsquares/abc.py
+++ b/sumofsquares/abc.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from enum import Enum, auto
 from typing import Any, Generic, TypeVar
 
-from sumofsquares.variable import OptVariable
+from sumofsquares.variable import OptSymbol
 
 
 # ┏━┓┏━┓╻  ╻ ╻┏━╸┏━┓
@@ -52,7 +52,7 @@ class Constraint(ABC, Generic[E]):
 class Result(ABC):
     """ Result of an optimization problem. """
     @abstractmethod
-    def value_of(self, var: OptVariable) -> float:
+    def value_of(self, var: OptSymbol) -> float:
         """ Retrieve value of variable. """
 
 
diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py
index f4f3d7b..7bcd629 100644
--- a/sumofsquares/problems.py
+++ b/sumofsquares/problems.py
@@ -17,17 +17,18 @@ from typing_extensions import override
 
 from polymatrix.expression.expression import Expression, VariableExpression
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.init import init_variable_expr
 from polymatrix.expressionstate import ExpressionState
 from polymatrix.polymatrix.mixins import PolyMatrixMixin
 from polymatrix.polymatrix.index import MonomialIndex, VariableIndex
-from polymatrix.variable import Variable
+from polymatrix.symbol import Symbol
 
 from .abc import Problem, Constraint, Solver, Result
 from .constraints import NonNegative, EqualToZero, PositiveSemiDefinite, ExponentialCone
 from .solver.cvxopt import solve_cone as cvxopt_solve_cone
 from .solver.scs import solve_cone as scs_solve_cone
 from .utils import partition
-from .variable import OptVariable
+from .variable import OptVariableExprMixin, OptSymbol
 
 
 # ┏━╸┏━┓┏┓╻╻┏━╸   ┏━┓┏━┓┏━┓┏┓ ╻  ┏━╸┏┳┓
@@ -89,7 +90,7 @@ class ConicProblem(Problem):
     """
 
     solver: Solver
-    variables: Sequence[OptVariable]
+    variables: dict[OptSymbol, tuple[int, int]]
 
     @property
     @override
@@ -112,32 +113,24 @@ class ConicProblem(Problem):
 @dataclassabc(frozen=True)
 class ConicResult(Result):
     """ Result of a Conic Problem """
-    values: dict[OptVariable, float]
+    values: dict[OptSymbol, float]
     solver_info: Any
 
     @override
-    def value_of(self, var: OptVariable | VariableExpression) -> float:
+    def value_of(self, var: OptSymbol| VariableExpression) -> float:
         if isinstance(var, VariableExpression):
-            if not isinstance(var.underlying, OptVariable):
+            if not isinstance(var.underlying, OptVariableExprMixin):
                 # TODO: error message
                 raise ValueError
             
             # Unwrap the expression
-            var = var.underlying
+            symbol = var.underlying.symbol
 
-        if var not in self.values:
-            # FIXME: this is a temporary fix here.
-            if isinstance(var.shape, ExpressionBaseMixin):
-                state = poly.make_state()
-                state, shapepm = var.shape.apply(state)
-                shape = (shapepm.at(0,0).constant(), shapepm.at(1,0).constant())
-                var = replace(var, shape=shape)
+            if var not in self.values:
+                raise KeyError(f"There is no result for the variable {var}. "
+                               f"Was the problem successfully solved?")
 
-                if var not in self.values:
-                    raise KeyError(f"There is no result for the variable {var}. "
-                                   f"Was the problem successfully solved?")
-
-        return self.values[var]
+        return self.values[symbol]
 
 
 # ┏━┓╻ ╻┏┳┓   ┏━┓┏━╸   ┏━┓┏━┓╻ ╻┏━┓┏━┓┏━╸┏━┓   ┏━┓┏━┓┏━┓┏━╸┏━┓┏━┓┏┳┓
@@ -201,17 +194,19 @@ class SOSProblem(Problem):
             state, pm  = c.expression.apply(state)
             variable_indices.update(pm.variables())
 
-        variables = set(state.get_variable_from_variable_index(v)
+        variables = set(state.get_symbol_from_variable_index(v)
                         for v in variable_indices)
 
-        # Collect variables
-        def is_optvariable(v):
-            return isinstance(v, OptVariable)
+        # Collect optimization variables
+        def is_opt(v):
+            return isinstance(v, OptSymbol)
 
-        polynomial_variables, variables = partition(is_optvariable, variables)
+        polynomial_variables, variables = partition(is_opt, variables)
         polynomial_variables = tuple(polynomial_variables) # because it is a generator
 
-        x = poly.v_stack((1,) + polynomial_variables)
+        x = poly.v_stack((1,) + tuple(
+            init_variable_expr(v, state.get_shape(v))
+            for v in polynomial_variables))
         for i, c in enumerate(self.constraints):
             if isinstance(c, EqualToZero):
                 state, deg = c.expression.degree().apply(state)
@@ -309,14 +304,14 @@ class InternalSOSProblem(Problem):
     """
     cost: PolyMatrixMixin
     constraints: Sequence[Constraint[PolyMatrixMixin]]
-    variables: Sequence[OptVariable]
-    polynomial_variables: Sequence[Variable]
+    variables: Sequence[OptSymbol]
+    polynomial_variables: Sequence[Symbol]
     solver: Solver
 
     # TODO: remove state field from this class, it is redundant
     state: ExpressionState 
 
-    def to_conic_problem(self) -> ConicProblem:
+    def to_conic_problem(self, verbose: bool = False) -> ConicProblem:
         """
         Conver the SOS problem into a Conic program. 
         """
@@ -410,12 +405,20 @@ class InternalSOSProblem(Problem):
         if all(len(cl) == 0 for cl in constraints.values()):
             raise ValueError("Optimization problem is unconstrained!")
 
+        if verbose:
+            # print("Conic problem has shapes: \n"
+            #       f"\t {q.shape = }\n")
+            pass
+
+
         return ConicProblem(P=P, q=q, constraints=constraints,
                             dims=dims, is_qp=is_qp,
                             solver=self.solver,
-                            variables=self.variables)
+                            variables={v : self.state.get_shape(v)
+                                for v in self.variables
+                            })
 
 
     @override
     def solve(self, verbose: bool = False) -> Result:
-        return self.to_conic_problem().solve(verbose)
+        return self.to_conic_problem(verbose).solve(verbose)
diff --git a/sumofsquares/solver/cvxopt.py b/sumofsquares/solver/cvxopt.py
index 1c37fbe..0692787 100644
--- a/sumofsquares/solver/cvxopt.py
+++ b/sumofsquares/solver/cvxopt.py
@@ -14,7 +14,7 @@ from pprint import pprint
 
 from ..abc import SolverInfo
 from ..error import SolverError, NotSupportedBySolver
-from ..variable import OptVariable
+from ..variable import OptSymbol
 
 if TYPE_CHECKING:
     from ..problems import ConicProblem
@@ -42,7 +42,7 @@ def vectorize_matrix(m: NDArray) -> NDArray:
 
 
 def solve_cone(prob: ConicProblem, verbose: bool = False,
-               *args, **kwargs) -> tuple[dict[OptVariable, NDArray | float], CVXOPTInfo]:
+               *args, **kwargs) -> tuple[dict[OptSymbol, NDArray | float], CVXOPTInfo]:
     r"""
     Any `*args` and `**kwargs` other than `prob` and `vebose` are passed to the
     CVXOPT solver.
@@ -150,9 +150,9 @@ def solve_cone(prob: ConicProblem, verbose: bool = False,
         return {}, CVXOPTInfo(info)
 
     results, i = {}, 0
-    for variable in prob.variables:
-        num_indices = math.prod(variable.shape)
-        values = np.array(info["x"][i:i+num_indices]).reshape(variable.shape)
+    for variable, shape in prob.variables.items():
+        num_indices = math.prod(shape)
+        values = np.array(info["x"][i:i+num_indices]).reshape(shape)
         if values.shape == (1, 1):
             values = values[0, 0]
 
diff --git a/sumofsquares/solver/mosek.py b/sumofsquares/solver/mosek.py
index de4f92f..6a8ce25 100644
--- a/sumofsquares/solver/mosek.py
+++ b/sumofsquares/solver/mosek.py
@@ -9,7 +9,7 @@ import mosek
 from pathlib import Path
 
 from ..abc import Problem, SolverInfo
-from ..variable import OptVariable
+from ..variable import OptSymbol
 
 
 class MOSEKInfo(SolverInfo):
@@ -40,7 +40,7 @@ def setup(license_file: Path | str | None = None):
 
 
 def solve_cone(prob: Problem, verbose: bool = False,
-               *args, **kwargs) -> tuple[dict[OptVariable, float], MOSEKInfo]:
+               *args, **kwargs) -> tuple[dict[OptSymbol, float], MOSEKInfo]:
     r"""
     Solve a conic problem in the cone of SOS polynomials
     :math:`\mathbf{\Sigma}_d(x)` using MOSEK.
diff --git a/sumofsquares/solver/scs.py b/sumofsquares/solver/scs.py
index e7ab6b8..a476700 100644
--- a/sumofsquares/solver/scs.py
+++ b/sumofsquares/solver/scs.py
@@ -14,7 +14,7 @@ from typing import TYPE_CHECKING
 
 from ..abc import SolverInfo
 from ..error import SolverError
-from ..variable import OptVariable
+from ..variable import OptSymbol
 
 if TYPE_CHECKING:
     from ..problems import ConicProblem
@@ -64,7 +64,7 @@ def mat(v: NDArray) -> NDArray:
 
 
 def solve_cone(prob: ConicProblem, verbose: bool = False,
-               *args, **kwargs) -> tuple[dict[OptVariable, float], SCSInfo]:
+               *args, **kwargs) -> tuple[dict[OptSymbol, float], SCSInfo]:
     r"""
     Any `*args` and `**kwargs` other than `prob` and `verbose` are passed
     directly to the SCS solver call.
@@ -153,9 +153,9 @@ def solve_cone(prob: ConicProblem, verbose: bool = False,
         return {}, SCSInfo(sol["info"])
 
     results, i = {}, 0
-    for variable in prob.variables:
-        num_indices = math.prod(variable.shape)
-        values = np.array(sol["x"][i:i+num_indices]).reshape(variable.shape)
+    for variable, shape in prob.variables.items():
+        num_indices = math.prod(shape)
+        values = np.array(sol["x"][i:i+num_indices]).reshape(shape)
         if values.shape == (1, 1):
             values = values[0, 0]
 
diff --git a/sumofsquares/variable.py b/sumofsquares/variable.py
index 2989732..59d771b 100644
--- a/sumofsquares/variable.py
+++ b/sumofsquares/variable.py
@@ -23,14 +23,14 @@ from polymatrix.expressionstate import ExpressionState
 from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
 from polymatrix.polymatrix.init import init_poly_matrix
 from polymatrix.polymatrix.mixins import PolyMatrixMixin
-from polymatrix.variable import Variable
+from polymatrix.symbol import Symbol
 
 
-class OptVariable(Variable):
-    """ Optimization (decision) variable. """
+class OptSymbol(Symbol):
+    """ Symbol for an optimization (decision) variable. """
 
 
-class OptVariableMixin(ExpressionBaseMixin, OptVariable):
+class OptVariableExprMixin(ExpressionBaseMixin):
     """ Optimization (decision) variable mixin for expression object. """
 
     @override
@@ -39,7 +39,12 @@ class OptVariableMixin(ExpressionBaseMixin, OptVariable):
     def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
         """ Shape of the optimization variable expression. """
 
-    @override
+    @property
+    @abstractmethod
+    def symbol(self) -> OptSymbol:
+        """ Symbol of the optimization variable. """
+
+    # @override
     def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
         if isinstance(self.shape, ExpressionBaseMixin):
             state, shape_pm = self.shape.apply(state)
@@ -52,19 +57,18 @@ class OptVariableMixin(ExpressionBaseMixin, OptVariable):
             ncols = int(shape_pm.at(1, 0).constant())
 
             # Replace shape field with computed shape
-            v = replace(self, shape=(nrows, ncols))
-            state = state.register(v) 
-            indices = state.get_indices(v)
+            state = state.register(self.symbol, shape=(nrows, ncols)) 
 
         elif isinstance(self.shape, tuple):
-            nrows, ncols = self.shape
-            state = state.register(self) 
-            indices = state.get_indices(self)
+            nrows, ncols = self.shape # for for loop below
+            state = state.register(self.symbol, self.shape) 
 
         else:
             raise ValueError("Shape must be a tuple or expression that "
                 f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}")
 
+        indices = state.get_indices(self.symbol)
+
         p = PolyMatrixDict()
         for (row, col), index in zip(product(range(nrows), range(ncols)), indices):
             p[row, col] = PolyDict({
@@ -77,23 +81,23 @@ class OptVariableMixin(ExpressionBaseMixin, OptVariable):
 
 
 @dataclassabc(frozen=True)
-class OptVariableImpl(OptVariableMixin):
-    name: str
-    shape: tuple[int, int] | ExpressionBaseMixin
+class OptVariableExprImpl(OptVariableExprMixin):
+    symbol: OptSymbol
+    shape: str
 
     def __str__(self):
-        return self.name
+        return self.symbol
 
 
-def init_opt_variable_expr(name, shape):
-    return OptVariableImpl(name, shape)
+def init_opt_variable_expr(variable, shape):
+    return OptVariableExprImpl(variable, shape)
 
 
 def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> VariableExpression:
     """ Construct an optimization variable. """
     if isinstance(shape, Expression):
         shape = shape.underlying
-    return init_variable_expression(underlying=init_opt_variable_expr(name, shape))
+    return init_variable_expression(underlying=init_opt_variable_expr(OptSymbol(name), shape))
 
 
 def from_names(names: str, shape: tuple[int, int]  | ExpressionBaseMixin = (1, 1)) -> Iterable[VariableExpression]:
-- 
cgit v1.2.1