summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--sumofsquares/__init__.py2
-rw-r--r--sumofsquares/optvariable.py80
2 files changed, 82 insertions, 0 deletions
diff --git a/sumofsquares/__init__.py b/sumofsquares/__init__.py
index 0aa52d6..0bc6e7d 100644
--- a/sumofsquares/__init__.py
+++ b/sumofsquares/__init__.py
@@ -1,3 +1,5 @@
from sumofsquares.sosexpr.abc import ParamSOSExpr, SOSExpr
from sumofsquares.sosexpr.init import init_sos_expr, init_param_expr, init_param_expr_from_reference, init_putinar_epsilon
from sumofsquares.cvxopt import solve_cone, solve_cone2, solve_sos_problem, solve_sos_problem2
+
+from sumofsquares.optvariable import from_names
diff --git a/sumofsquares/optvariable.py b/sumofsquares/optvariable.py
new file mode 100644
index 0000000..cf4e99b
--- /dev/null
+++ b/sumofsquares/optvariable.py
@@ -0,0 +1,80 @@
+from itertools import product
+from typing_extensions import override
+from dataclasses import replace
+from dataclassabc import dataclassabc
+
+from polymatrix.expression.expression import Expression, init_expression
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.abc import ExpressionState
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
+from polymatrix.variable.abc import Variable
+
+
+class OptVariableMixin(ExpressionBaseMixin, Variable):
+ """ Optimization (decision) variable. """
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ state, indices = state.index(self)
+ p = PolyMatrixDict()
+
+ rows, cols = self.shape
+ for (row, col), index in zip(product(range(rows), range(cols)), indices):
+ p[row, col] = PolyDict({
+ # Create monomial with variable to the first power
+ # with coefficient of one
+ MonomialIndex((VariableIndex(index, power=1),)): 1.
+ })
+
+ return state, init_poly_matrix(p, self.shape)
+
+
+@dataclassabc(frozen=True)
+class OptVariableImpl(OptVariableMixin):
+ name: str
+ shape: tuple[int, int]
+
+
+def init_opt_variable_expr(name, shape):
+ return OptVariableImpl(name, shape)
+
+
+class OptVariableExpression(Expression, Variable):
+ """
+ Expression that is an optimization (decision) variable, i.e. an expression
+ that cannot be reduced further.
+ """
+ @override
+ @property
+ def name(self):
+ return self.underlyng.name
+
+ @override
+ @property
+ def shape(self):
+ return self.underlyng.shape
+
+
+@dataclassabc(frozen=True)
+class OptVariableExpressionImpl(OptVariableExpression):
+ underlying: ExpressionBaseMixin
+
+ def copy(self, underlying: ExpressionBaseMixin) -> Expression:
+ return init_expression(underlying)
+
+
+def init_opt_variable_expression(underlying: OptVariableMixin) -> OptVariableExpression:
+ return OptVariableExpressionImpl(underlying=underlying)
+
+
+def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[OptVariableExpression] | OptVariableExpression:
+ """ Construct one or multiple variables from comma separated a list of names. """
+ variables = tuple(init_opt_variable_expression(
+ underlying=init_opt_variable_expr(name.strip(), shape))
+ for name in names.split(","))
+
+ if len(variables) == 1:
+ return variables[0]
+
+ return variables