summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-24 21:19:55 +0200
committerNao Pross <np@0hm.ch>2024-05-24 21:19:55 +0200
commit2ce3633b355a9419a64ee05c2acd53540e176c6f (patch)
tree909bf6901ede6af3ffab545366b7128e079d4f3f
parentUpdate class diagram in __init__ (diff)
downloadsumofsquares-2ce3633b355a9419a64ee05c2acd53540e176c6f.tar.gz
sumofsquares-2ce3633b355a9419a64ee05c2acd53540e176c6f.zip
Allow shape of OptVariableExpr to be an Expression
-rw-r--r--sumofsquares/__init__.py9
-rw-r--r--sumofsquares/variable.py53
2 files changed, 49 insertions, 13 deletions
diff --git a/sumofsquares/__init__.py b/sumofsquares/__init__.py
index 54aec7e..6ca2b39 100644
--- a/sumofsquares/__init__.py
+++ b/sumofsquares/__init__.py
@@ -75,9 +75,14 @@ from .abc import Problem, Result, Set, Constraint, Solver
from .constraints import BasicSemialgebraicSet, NonNegative
from .problems import PutinarSOSProblem
from .utils import partition
-from .variable import OptVariable, from_names as internal_from_names
+from .variable import (
+ OptVariable,
+ from_name as internal_from_name,
+ from_names as internal_from_names)
-# export function
+
+# export optimization variable constructors
+from_name = internal_from_name
from_names = internal_from_names
def make_sos_constraint(expr: Expression, domain: Set | None = None) -> NonNegative:
diff --git a/sumofsquares/variable.py b/sumofsquares/variable.py
index f309406..28bbc97 100644
--- a/sumofsquares/variable.py
+++ b/sumofsquares/variable.py
@@ -10,11 +10,14 @@ this extends the following parts of polymatrix:
from __future__ import annotations
+from abc import abstractmethod
from itertools import product
+from typing import Iterable
from typing_extensions import override
+from dataclasses import replace
from dataclassabc import dataclassabc
-from polymatrix.expression.expression import VariableExpression, init_variable_expression
+from polymatrix.expression.expression import Expression, VariableExpression, init_variable_expression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate import ExpressionState
from polymatrix.polymatrix.abc import PolyMatrix
@@ -31,26 +34,52 @@ class OptVariableMixin(ExpressionBaseMixin, OptVariable):
""" Optimization (decision) variable mixin for expression object. """
@override
- def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
- state = state.register(self)
- indices = state.get_indices(self)
- p = PolyMatrixDict()
+ @property
+ @abstractmethod
+ def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
+ """ Shape of the optimization variable expression. """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ if isinstance(self.shape, ExpressionBaseMixin):
+ state, shape_pm = self.shape.apply(state)
+ if shape_pm.shape != (2, 1):
+ raise ValueError("If shape is an expression it must evaluate to a 2d row vector, "
+ f"but here it has shape {shape_pm.shape}")
+
+ # FIXME: should check that they are actually integers
+ nrows = int(shape_pm.at(0, 0).constant())
+ ncols = int(shape_pm.at(1, 0).constant())
+
+ # Replace shape field with computed shape
+ v = replace(self, shape=(nrows, ncols))
+ state = state.register(v)
+ indices = state.get_indices(v)
+
+ elif isinstance(self.shape, tuple):
+ nrows, ncols = self.shape
+ state = state.register(self)
+ indices = state.get_indices(self)
+
+ else:
+ raise ValueError("Shape must be a tuple or expression that "
+ f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}")
- rows, cols = self.shape
- for (row, col), index in zip(product(range(rows), range(cols)), indices):
+ p = PolyMatrixDict()
+ for (row, col), index in zip(product(range(nrows), range(ncols)), indices):
p[row, col] = PolyDict({
# Create monomial with variable to the first power
# with coefficient of one
MonomialIndex((VariableIndex(index, power=1),)): 1.
})
- return state, init_poly_matrix(p, self.shape)
+ return state, init_poly_matrix(p, shape=(nrows, ncols))
@dataclassabc(frozen=True)
class OptVariableImpl(OptVariableMixin):
name: str
- shape: tuple[int, int]
+ shape: tuple[int, int] | ExpressionBaseMixin
def __str__(self):
return self.name
@@ -60,12 +89,14 @@ def init_opt_variable_expr(name, shape):
return OptVariableImpl(name, shape)
-def from_name(name: str, shape: tuple[int, int] = (1, 1)) -> VariableExpression:
+def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> VariableExpression:
""" Construct an optimization variable. """
+ if isinstance(shape, Expression):
+ shape = shape.underlying
return init_variable_expression(underlying=init_opt_variable_expr(name, shape))
-def from_names(names: str, shape: tuple[int, int] = (1, 1)) -> Iterable[VariableExpression]:
+def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1, 1)) -> Iterable[VariableExpression]:
""" Construct one or multiple variables from comma separated a list of names. """
for name in names.split(","):
yield from_name(name.strip(), shape)