diff options
Diffstat (limited to '')
-rw-r--r-- | sumofsquares/expression.py | 137 |
1 files changed, 0 insertions, 137 deletions
diff --git a/sumofsquares/expression.py b/sumofsquares/expression.py deleted file mode 100644 index 28f1794..0000000 --- a/sumofsquares/expression.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import Iterable -from typing_extensions import override -from dataclassabc import dataclassabc - -from polymatrix.expression.expression import ExpressionBaseMixin -from polymatrix.expression.from_ import from_statemonad -from polymatrix.expressionstate import ExpressionState -from polymatrix.statemonad import init_state_monad -from polymatrix.polymatrix.mixins import PolyMatrixMixin - -from polymatrix.expression.init import ( - init_concatenate_expr, - init_diag_expr, - init_lower_triangular_expr, - init_transpose_expr, - init_slice_expr) - -from .abc import Constraint -from .constraints import PositiveSemidefinite, ExponentialCone -from .error import SolverError -from .variable import init_opt_variable_expr - - -class SOSExpressionBaseMixin(ExpressionBaseMixin): - @override - def apply(self, state: ExpressionState) -> tuple[ExpressionState, ExpressionBaseMixin]: - raise SolverError(f"Expression containing {self.__class__.__qualname__} " - "cannot be used directly, they need to be rewritten " - "into an equivalent form using a canonicalization function.") - - @abstractmethod - def recast(self) -> tuple[ExpressionBaseMixin, Iterable[Constraint]]: - """ - Recast the expression into a form that can be directly used for - optimization, possibly by introducing new constraints. - - The return values are the new expression to be minimized and the new - constraints that have to be added. - """ - - -class LogDetMixin(SOSExpressionBaseMixin): - """ Compute the sum of the logarithm of the eigenvalues. """ - - @property - @abstractmethod - def underlying(self) -> ExpressionBaseMixin: - """" Take the logdet of this expression """ - - @override - def recast(self) -> tuple[ExpressionBaseMixin, Iterable[Constraint]]: - # The problem - # - # maximize logdet(A) - # - # is equivalent to solving - # - # maximize t - # - # subject to [ A Z ] - # [ Z.T diag(Z) ] >= 0 - # - # Z lower triangular - # t <= sum_i log(Z[i,i]) - # - # and the last constraint of the above is equivalent to - # - # t <= sum_i u[i] - # u_i <= log(Z[i, i]) for all i - # - # And finally to get rid of the log the latter constraint one is - # equivalent to - # - # (Z[i,i], 1, u[i]) in Exponential Cone for all i - # - # Hence we can replace the original problem with - # - # minimize - sum_i u[i] - # - # subject to [ A Z ] - # [ Z.T diag(Z) ] >= 0 - # - # Z lower triangular - # (Z[i,i], 1, u[i]) in ExpCone for all i - - A = self.underlying - - # FIXME: get rid of these functions, create ShapeExprMixin in polymatrix? - def make_u(state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: - state, pm = A.apply(state) - n, m = pm.shape - if n != m: - raise ValueError(f"Matrix A of logdet(A) must be square, " - f"but it has shape {pm.shape}") - - # FIXME: should check to avoid name clashes - u = init_opt_variable_expr("u_logdet", shape=(n, 1)) - return u.apply(state) - - def make_z(state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: - state, pm = A.apply(state) - n, m = pm.shape - if n != m: - raise ValueError(f"Matrix A of logdet(A) must be square, " - f"but it has shape {pm.shape}") - - # FIXME: should check to avoid name clashes - Z = init_lower_triangular_expr(init_opt_variable_expr("Z_logdet", shape=(n * (n + 1) // 2, 1))) - return Z.apply(state) - - Z = from_statemonad(init_state_monad(make_z)) - Z_T = init_transpose_expr(Z) - Z_diag = init_diag_expr(Z) - - u = from_statemonad(init_state_monad(make_u)) - - # we call the new big matrix Q - Q = init_concatenate_expr(((A, Z), (Z_T, Z_diag))) - - def make_expcones(state: ExpressionState) -> tuple[ExpressionState, tuple[Constraint]]: - pass - - - constraints = [PositiveSemidefinite(Q)] - - raise NotImplementedError - - -@dataclassabc(froze=True) -class LogDetImpl(LogDetMixin): - underlying: ExpressionBaseMixin - - def __str__(self): - return f"logdet({self.underlying})" |