diff options
author | Nao Pross <np@0hm.ch> | 2024-06-02 15:44:12 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-06-02 15:44:12 +0200 |
commit | 9137149d33eb53d74e940310f27d935fad10a813 (patch) | |
tree | 498b72830b10e2a2b794cd5591aa43993d7df2d9 | |
parent | Fix PSatz canon multiplier degrees (diff) | |
download | sumofsquares-9137149d33eb53d74e940310f27d935fad10a813.tar.gz sumofsquares-9137149d33eb53d74e940310f27d935fad10a813.zip |
Optimize size of SDP by taking variables from expressions instead of state
As discussed during meeting nr. 12
-rw-r--r-- | sumofsquares/problems.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py index bfb13c3..f4f3d7b 100644 --- a/sumofsquares/problems.py +++ b/sumofsquares/problems.py @@ -19,7 +19,7 @@ from polymatrix.expression.expression import Expression, VariableExpression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.polymatrix.index import MonomialIndex +from polymatrix.polymatrix.index import MonomialIndex, VariableIndex from polymatrix.variable import Variable from .abc import Problem, Constraint, Solver, Result @@ -186,24 +186,29 @@ class SOSProblem(Problem): Likewise the cost function must also be reduced to quadratic expression here. """ - + variable_indices: set[VariableIndex] = set() constraints: list[Constraint] = [] + state, cost = self.cost.apply(state) + variable_indices.update(cost.variables()) # Compute the polymatrix of each constraint expression. Even though the # result may not be used, this is necessary to account for the case of # variables that are only present in the constraints, which the state # object may not contain yet. - # TODO: Is there a more efficient way? There has to be a better way. for c in self.constraints: - state, _ = c.expression.apply(state) + state, pm = c.expression.apply(state) + variable_indices.update(pm.variables()) + + variables = set(state.get_variable_from_variable_index(v) + for v in variable_indices) # Collect variables def is_optvariable(v): return isinstance(v, OptVariable) - polynomial_variables, variables = partition(is_optvariable, state.indices.keys()) + polynomial_variables, variables = partition(is_optvariable, variables) polynomial_variables = tuple(polynomial_variables) # because it is a generator x = poly.v_stack((1,) + polynomial_variables) |