diff options
Diffstat (limited to '')
-rw-r--r-- | sumofsquares/__init__.py | 2 | ||||
-rw-r--r-- | sumofsquares/optvariable.py | 80 |
2 files changed, 82 insertions, 0 deletions
diff --git a/sumofsquares/__init__.py b/sumofsquares/__init__.py index 0aa52d6..0bc6e7d 100644 --- a/sumofsquares/__init__.py +++ b/sumofsquares/__init__.py @@ -1,3 +1,5 @@ from sumofsquares.sosexpr.abc import ParamSOSExpr, SOSExpr from sumofsquares.sosexpr.init import init_sos_expr, init_param_expr, init_param_expr_from_reference, init_putinar_epsilon from sumofsquares.cvxopt import solve_cone, solve_cone2, solve_sos_problem, solve_sos_problem2 + +from sumofsquares.optvariable import from_names diff --git a/sumofsquares/optvariable.py b/sumofsquares/optvariable.py new file mode 100644 index 0000000..cf4e99b --- /dev/null +++ b/sumofsquares/optvariable.py @@ -0,0 +1,80 @@ +from itertools import product +from typing_extensions import override +from dataclasses import replace +from dataclassabc import dataclassabc + +from polymatrix.expression.expression import Expression, init_expression +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex +from polymatrix.variable.abc import Variable + + +class OptVariableMixin(ExpressionBaseMixin, Variable): + """ Optimization (decision) variable. """ + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + state, indices = state.index(self) + p = PolyMatrixDict() + + rows, cols = self.shape + for (row, col), index in zip(product(range(rows), range(cols)), indices): + p[row, col] = PolyDict({ + # Create monomial with variable to the first power + # with coefficient of one + MonomialIndex((VariableIndex(index, power=1),)): 1. + }) + + return state, init_poly_matrix(p, self.shape) + + +@dataclassabc(frozen=True) +class OptVariableImpl(OptVariableMixin): + name: str + shape: tuple[int, int] + + +def init_opt_variable_expr(name, shape): + return OptVariableImpl(name, shape) + + +class OptVariableExpression(Expression, Variable): + """ + Expression that is an optimization (decision) variable, i.e. an expression + that cannot be reduced further. + """ + @override + @property + def name(self): + return self.underlyng.name + + @override + @property + def shape(self): + return self.underlyng.shape + + +@dataclassabc(frozen=True) +class OptVariableExpressionImpl(OptVariableExpression): + underlying: ExpressionBaseMixin + + def copy(self, underlying: ExpressionBaseMixin) -> Expression: + return init_expression(underlying) + + +def init_opt_variable_expression(underlying: OptVariableMixin) -> OptVariableExpression: + return OptVariableExpressionImpl(underlying=underlying) + + +def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[OptVariableExpression] | OptVariableExpression: + """ Construct one or multiple variables from comma separated a list of names. """ + variables = tuple(init_opt_variable_expression( + underlying=init_opt_variable_expr(name.strip(), shape)) + for name in names.split(",")) + + if len(variables) == 1: + return variables[0] + + return variables |