diff options
-rw-r--r-- | sumofsquares/problems.py | 6 | ||||
-rw-r--r-- | sumofsquares/variable.py | 86 |
2 files changed, 16 insertions, 76 deletions
diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py index 28327a1..8f904c5 100644 --- a/sumofsquares/problems.py +++ b/sumofsquares/problems.py @@ -16,7 +16,7 @@ from typing import Any, Sequence from typing_extensions import override from polymatrix.expression.expression import Expression, VariableExpression -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.variableexprmixin import VariableExprMixin from polymatrix.expression.init import init_variable_expr from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin @@ -28,7 +28,7 @@ from .constraints import NonNegative, EqualToZero, PositiveSemiDefinite, Exponen from .solver.cvxopt import solve_cone as cvxopt_solve_cone from .solver.scs import solve_cone as scs_solve_cone from .utils import partition -from .variable import OptVariableExprMixin, OptSymbol +from .variable import OptSymbol # ┏━╸┏━┓┏┓╻╻┏━╸ ┏━┓┏━┓┏━┓┏┓ ╻ ┏━╸┏┳┓ @@ -119,7 +119,7 @@ class ConicResult(Result): @override def value_of(self, var: OptSymbol| VariableExpression) -> float: if isinstance(var, VariableExpression): - if not isinstance(var.underlying, OptVariableExprMixin): + if not isinstance(var.underlying, VariableExprMixin): # TODO: error message raise ValueError diff --git a/sumofsquares/variable.py b/sumofsquares/variable.py index 59d771b..0e87e74 100644 --- a/sumofsquares/variable.py +++ b/sumofsquares/variable.py @@ -10,19 +10,11 @@ this extends the following parts of polymatrix: from __future__ import annotations -from abc import abstractmethod -from itertools import product from typing import Iterable -from typing_extensions import override -from dataclasses import replace -from dataclassabc import dataclassabc from polymatrix.expression.expression import Expression, VariableExpression, init_variable_expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate import ExpressionState -from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.expression.init import init_variable_expr from polymatrix.symbol import Symbol @@ -30,74 +22,22 @@ class OptSymbol(Symbol): """ Symbol for an optimization (decision) variable. """ -class OptVariableExprMixin(ExpressionBaseMixin): - """ Optimization (decision) variable mixin for expression object. """ - - @override - @property - @abstractmethod - def shape(self) -> tuple[int, int] | ExpressionBaseMixin: - """ Shape of the optimization variable expression. """ - - @property - @abstractmethod - def symbol(self) -> OptSymbol: - """ Symbol of the optimization variable. """ - - # @override - def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: - if isinstance(self.shape, ExpressionBaseMixin): - state, shape_pm = self.shape.apply(state) - if shape_pm.shape != (2, 1): - raise ValueError("If shape is an expression it must evaluate to a 2d row vector, " - f"but here it has shape {shape_pm.shape}") - - # FIXME: should check that they are actually integers - nrows = int(shape_pm.at(0, 0).constant()) - ncols = int(shape_pm.at(1, 0).constant()) - - # Replace shape field with computed shape - state = state.register(self.symbol, shape=(nrows, ncols)) - - elif isinstance(self.shape, tuple): - nrows, ncols = self.shape # for for loop below - state = state.register(self.symbol, self.shape) - - else: - raise ValueError("Shape must be a tuple or expression that " - f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}") - - indices = state.get_indices(self.symbol) - - p = PolyMatrixDict() - for (row, col), index in zip(product(range(nrows), range(ncols)), indices): - p[row, col] = PolyDict({ - # Create monomial with variable to the first power - # with coefficient of one - MonomialIndex((VariableIndex(index, power=1),)): 1. - }) - - return state, init_poly_matrix(p, shape=(nrows, ncols)) - - -@dataclassabc(frozen=True) -class OptVariableExprImpl(OptVariableExprMixin): - symbol: OptSymbol - shape: str - - def __str__(self): - return self.symbol - - -def init_opt_variable_expr(variable, shape): - return OptVariableExprImpl(variable, shape) - - def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> VariableExpression: """ Construct an optimization variable. """ if isinstance(shape, Expression): shape = shape.underlying - return init_variable_expression(underlying=init_opt_variable_expr(OptSymbol(name), shape)) + + elif isinstance(shape, tuple): + nrows, ncols = shape + if isinstance(nrows, Expression): + nrows = nrows.underlying + + if isinstance(ncols, Expression): + ncols = ncols.underlying + + shape = (nrows, ncols) + + return init_variable_expression(underlying=init_variable_expr(OptSymbol(name), shape)) def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> Iterable[VariableExpression]: |