summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-06-06 00:25:33 +0200
committerNao Pross <np@0hm.ch>2024-06-06 00:25:33 +0200
commite8bd49a7b03d932bcbf87826028a9498a6bc2b67 (patch)
treeead3a4ce8e283fb650ba413a3939e0ca4c2e9bba
parentImplement semidefinite constraints in SCS (diff)
downloadsumofsquares-e8bd49a7b03d932bcbf87826028a9498a6bc2b67.tar.gz
sumofsquares-e8bd49a7b03d932bcbf87826028a9498a6bc2b67.zip
Delete class OptVariableExprMixin made obsolete by new Symbol in polymatrix
-rw-r--r--sumofsquares/problems.py6
-rw-r--r--sumofsquares/variable.py86
2 files changed, 16 insertions, 76 deletions
diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py
index 28327a1..8f904c5 100644
--- a/sumofsquares/problems.py
+++ b/sumofsquares/problems.py
@@ -16,7 +16,7 @@ from typing import Any, Sequence
from typing_extensions import override
from polymatrix.expression.expression import Expression, VariableExpression
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.variableexprmixin import VariableExprMixin
from polymatrix.expression.init import init_variable_expr
from polymatrix.expressionstate import ExpressionState
from polymatrix.polymatrix.mixins import PolyMatrixMixin
@@ -28,7 +28,7 @@ from .constraints import NonNegative, EqualToZero, PositiveSemiDefinite, Exponen
from .solver.cvxopt import solve_cone as cvxopt_solve_cone
from .solver.scs import solve_cone as scs_solve_cone
from .utils import partition
-from .variable import OptVariableExprMixin, OptSymbol
+from .variable import OptSymbol
# ┏━╸┏━┓┏┓╻╻┏━╸ ┏━┓┏━┓┏━┓┏┓ ╻ ┏━╸┏┳┓
@@ -119,7 +119,7 @@ class ConicResult(Result):
@override
def value_of(self, var: OptSymbol| VariableExpression) -> float:
if isinstance(var, VariableExpression):
- if not isinstance(var.underlying, OptVariableExprMixin):
+ if not isinstance(var.underlying, VariableExprMixin):
# TODO: error message
raise ValueError
diff --git a/sumofsquares/variable.py b/sumofsquares/variable.py
index 59d771b..0e87e74 100644
--- a/sumofsquares/variable.py
+++ b/sumofsquares/variable.py
@@ -10,19 +10,11 @@ this extends the following parts of polymatrix:
from __future__ import annotations
-from abc import abstractmethod
-from itertools import product
from typing import Iterable
-from typing_extensions import override
-from dataclasses import replace
-from dataclassabc import dataclassabc
from polymatrix.expression.expression import Expression, VariableExpression, init_variable_expression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expressionstate import ExpressionState
-from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
-from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.expression.init import init_variable_expr
from polymatrix.symbol import Symbol
@@ -30,74 +22,22 @@ class OptSymbol(Symbol):
""" Symbol for an optimization (decision) variable. """
-class OptVariableExprMixin(ExpressionBaseMixin):
- """ Optimization (decision) variable mixin for expression object. """
-
- @override
- @property
- @abstractmethod
- def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
- """ Shape of the optimization variable expression. """
-
- @property
- @abstractmethod
- def symbol(self) -> OptSymbol:
- """ Symbol of the optimization variable. """
-
- # @override
- def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
- if isinstance(self.shape, ExpressionBaseMixin):
- state, shape_pm = self.shape.apply(state)
- if shape_pm.shape != (2, 1):
- raise ValueError("If shape is an expression it must evaluate to a 2d row vector, "
- f"but here it has shape {shape_pm.shape}")
-
- # FIXME: should check that they are actually integers
- nrows = int(shape_pm.at(0, 0).constant())
- ncols = int(shape_pm.at(1, 0).constant())
-
- # Replace shape field with computed shape
- state = state.register(self.symbol, shape=(nrows, ncols))
-
- elif isinstance(self.shape, tuple):
- nrows, ncols = self.shape # for for loop below
- state = state.register(self.symbol, self.shape)
-
- else:
- raise ValueError("Shape must be a tuple or expression that "
- f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}")
-
- indices = state.get_indices(self.symbol)
-
- p = PolyMatrixDict()
- for (row, col), index in zip(product(range(nrows), range(ncols)), indices):
- p[row, col] = PolyDict({
- # Create monomial with variable to the first power
- # with coefficient of one
- MonomialIndex((VariableIndex(index, power=1),)): 1.
- })
-
- return state, init_poly_matrix(p, shape=(nrows, ncols))
-
-
-@dataclassabc(frozen=True)
-class OptVariableExprImpl(OptVariableExprMixin):
- symbol: OptSymbol
- shape: str
-
- def __str__(self):
- return self.symbol
-
-
-def init_opt_variable_expr(variable, shape):
- return OptVariableExprImpl(variable, shape)
-
-
def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> VariableExpression:
""" Construct an optimization variable. """
if isinstance(shape, Expression):
shape = shape.underlying
- return init_variable_expression(underlying=init_opt_variable_expr(OptSymbol(name), shape))
+
+ elif isinstance(shape, tuple):
+ nrows, ncols = shape
+ if isinstance(nrows, Expression):
+ nrows = nrows.underlying
+
+ if isinstance(ncols, Expression):
+ ncols = ncols.underlying
+
+ shape = (nrows, ncols)
+
+ return init_variable_expression(underlying=init_variable_expr(OptSymbol(name), shape))
def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> Iterable[VariableExpression]: