summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--sumofsquares/problems.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py
index bfb13c3..f4f3d7b 100644
--- a/sumofsquares/problems.py
+++ b/sumofsquares/problems.py
@@ -19,7 +19,7 @@ from polymatrix.expression.expression import Expression, VariableExpression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate import ExpressionState
from polymatrix.polymatrix.mixins import PolyMatrixMixin
-from polymatrix.polymatrix.index import MonomialIndex
+from polymatrix.polymatrix.index import MonomialIndex, VariableIndex
from polymatrix.variable import Variable
from .abc import Problem, Constraint, Solver, Result
@@ -186,24 +186,29 @@ class SOSProblem(Problem):
Likewise the cost function must also be reduced to quadratic expression
here.
"""
-
+ variable_indices: set[VariableIndex] = set()
constraints: list[Constraint] = []
+
state, cost = self.cost.apply(state)
+ variable_indices.update(cost.variables())
# Compute the polymatrix of each constraint expression. Even though the
# result may not be used, this is necessary to account for the case of
# variables that are only present in the constraints, which the state
# object may not contain yet.
- # TODO: Is there a more efficient way? There has to be a better way.
for c in self.constraints:
- state, _ = c.expression.apply(state)
+ state, pm = c.expression.apply(state)
+ variable_indices.update(pm.variables())
+
+ variables = set(state.get_variable_from_variable_index(v)
+ for v in variable_indices)
# Collect variables
def is_optvariable(v):
return isinstance(v, OptVariable)
- polynomial_variables, variables = partition(is_optvariable, state.indices.keys())
+ polynomial_variables, variables = partition(is_optvariable, variables)
polynomial_variables = tuple(polynomial_variables) # because it is a generator
x = poly.v_stack((1,) + polynomial_variables)