diff options
author | Nao Pross <np@0hm.ch> | 2024-05-24 21:19:55 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-24 21:19:55 +0200 |
commit | 2ce3633b355a9419a64ee05c2acd53540e176c6f (patch) | |
tree | 909bf6901ede6af3ffab545366b7128e079d4f3f | |
parent | Update class diagram in __init__ (diff) | |
download | sumofsquares-2ce3633b355a9419a64ee05c2acd53540e176c6f.tar.gz sumofsquares-2ce3633b355a9419a64ee05c2acd53540e176c6f.zip |
Allow shape of OptVariableExpr to be an Expression
-rw-r--r-- | sumofsquares/__init__.py | 9 | ||||
-rw-r--r-- | sumofsquares/variable.py | 53 |
2 files changed, 49 insertions, 13 deletions
diff --git a/sumofsquares/__init__.py b/sumofsquares/__init__.py index 54aec7e..6ca2b39 100644 --- a/sumofsquares/__init__.py +++ b/sumofsquares/__init__.py @@ -75,9 +75,14 @@ from .abc import Problem, Result, Set, Constraint, Solver from .constraints import BasicSemialgebraicSet, NonNegative from .problems import PutinarSOSProblem from .utils import partition -from .variable import OptVariable, from_names as internal_from_names +from .variable import ( + OptVariable, + from_name as internal_from_name, + from_names as internal_from_names) -# export function + +# export optimization variable constructors +from_name = internal_from_name from_names = internal_from_names def make_sos_constraint(expr: Expression, domain: Set | None = None) -> NonNegative: diff --git a/sumofsquares/variable.py b/sumofsquares/variable.py index f309406..28bbc97 100644 --- a/sumofsquares/variable.py +++ b/sumofsquares/variable.py @@ -10,11 +10,14 @@ this extends the following parts of polymatrix: from __future__ import annotations +from abc import abstractmethod from itertools import product +from typing import Iterable from typing_extensions import override +from dataclasses import replace from dataclassabc import dataclassabc -from polymatrix.expression.expression import VariableExpression, init_variable_expression +from polymatrix.expression.expression import Expression, VariableExpression, init_variable_expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix @@ -31,26 +34,52 @@ class OptVariableMixin(ExpressionBaseMixin, OptVariable): """ Optimization (decision) variable mixin for expression object. """ @override - def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: - state = state.register(self) - indices = state.get_indices(self) - p = PolyMatrixDict() + @property + @abstractmethod + def shape(self) -> tuple[int, int] | ExpressionBaseMixin: + """ Shape of the optimization variable expression. """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + if isinstance(self.shape, ExpressionBaseMixin): + state, shape_pm = self.shape.apply(state) + if shape_pm.shape != (2, 1): + raise ValueError("If shape is an expression it must evaluate to a 2d row vector, " + f"but here it has shape {shape_pm.shape}") + + # FIXME: should check that they are actually integers + nrows = int(shape_pm.at(0, 0).constant()) + ncols = int(shape_pm.at(1, 0).constant()) + + # Replace shape field with computed shape + v = replace(self, shape=(nrows, ncols)) + state = state.register(v) + indices = state.get_indices(v) + + elif isinstance(self.shape, tuple): + nrows, ncols = self.shape + state = state.register(self) + indices = state.get_indices(self) + + else: + raise ValueError("Shape must be a tuple or expression that " + f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}") - rows, cols = self.shape - for (row, col), index in zip(product(range(rows), range(cols)), indices): + p = PolyMatrixDict() + for (row, col), index in zip(product(range(nrows), range(ncols)), indices): p[row, col] = PolyDict({ # Create monomial with variable to the first power # with coefficient of one MonomialIndex((VariableIndex(index, power=1),)): 1. }) - return state, init_poly_matrix(p, self.shape) + return state, init_poly_matrix(p, shape=(nrows, ncols)) @dataclassabc(frozen=True) class OptVariableImpl(OptVariableMixin): name: str - shape: tuple[int, int] + shape: tuple[int, int] | ExpressionBaseMixin def __str__(self): return self.name @@ -60,12 +89,14 @@ def init_opt_variable_expr(name, shape): return OptVariableImpl(name, shape) -def from_name(name: str, shape: tuple[int, int] = (1, 1)) -> VariableExpression: +def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> VariableExpression: """ Construct an optimization variable. """ + if isinstance(shape, Expression): + shape = shape.underlying return init_variable_expression(underlying=init_opt_variable_expr(name, shape)) -def from_names(names: str, shape: tuple[int, int] = (1, 1)) -> Iterable[VariableExpression]: +def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> Iterable[VariableExpression]: """ Construct one or multiple variables from comma separated a list of names. """ for name in names.split(","): yield from_name(name.strip(), shape) |