diff options
-rw-r--r-- | sumofsquares/canon.py | 14 | ||||
-rw-r--r-- | sumofsquares/problems.py | 6 |
2 files changed, 10 insertions, 10 deletions
diff --git a/sumofsquares/canon.py b/sumofsquares/canon.py index 94759f2..02c9266 100644 --- a/sumofsquares/canon.py +++ b/sumofsquares/canon.py @@ -7,6 +7,7 @@ from dataclasses import replace from typing_extensions import override from polymatrix.expression.from_ import Expression +from polymatrix.expression.init import init_variable_expr from polymatrix.expressionstate import ExpressionState from .abc import Problem, Solver, Constraint, Result @@ -67,12 +68,9 @@ class Canonicalization(Problem): """ The problem that will be canonicalized """ @override - def solve(self, verbose: bool = False, state: ExpressionState | None = None) -> Result: - if state is None: - state = poly.make_state() - + def solve(self, state: ExpressionState, verbose: bool = False) -> tuple[ExpressionState, Result]: state, internal_prob = self.apply(state) - return internal_prob.solve(verbose) + return internal_prob.solve(state, verbose) @abstractmethod def apply(self, state: ExpressionState) -> tuple[ExpressionState, InternalSOSProblem]: @@ -139,10 +137,12 @@ class PutinarPSatz(Canonicalization): "Degree of domain polynomial is higher than constraint polynomial!" # Degree of multiplier must be even - if d % 2 != 0: + if d.read(state).scalar().constant() % 2 != 0: d += 1 - x = poly.v_stack((1,) + prob.polynomial_variables) + x = poly.v_stack((1,) + tuple( + init_variable_expr(v, state.get_shape(v)) + for v in prob.polynomial_variables)) # FIXME: need to check that there is not a \gamma_{i} already! # TODO: deterministically generate unique names diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py index 14618fe..2ba9ee4 100644 --- a/sumofsquares/problems.py +++ b/sumofsquares/problems.py @@ -126,8 +126,8 @@ class ConicResult(Result): # Unwrap the expression symbol = var.underlying.symbol - if var not in self.values: - raise KeyError(f"There is no result for the variable {var}. " + if symbol not in self.values: + raise KeyError(f"There is no result for the variable {symbol}. " f"Was the problem successfully solved?") return self.values[symbol] @@ -415,4 +415,4 @@ class InternalSOSProblem(Problem): @override def solve(self, state: ExpressionState, verbose: bool = False) -> tuple[ExpressionState, Result]: - return state, self.to_conic_problem(state, verbose).solve(verbose) + return state, self.to_conic_problem(state=state, verbose=verbose).solve(verbose) |