summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--sumofsquares/canon.py14
-rw-r--r--sumofsquares/problems.py6
2 files changed, 10 insertions, 10 deletions
diff --git a/sumofsquares/canon.py b/sumofsquares/canon.py
index 94759f2..02c9266 100644
--- a/sumofsquares/canon.py
+++ b/sumofsquares/canon.py
@@ -7,6 +7,7 @@ from dataclasses import replace
from typing_extensions import override
from polymatrix.expression.from_ import Expression
+from polymatrix.expression.init import init_variable_expr
from polymatrix.expressionstate import ExpressionState
from .abc import Problem, Solver, Constraint, Result
@@ -67,12 +68,9 @@ class Canonicalization(Problem):
""" The problem that will be canonicalized """
@override
- def solve(self, verbose: bool = False, state: ExpressionState | None = None) -> Result:
- if state is None:
- state = poly.make_state()
-
+ def solve(self, state: ExpressionState, verbose: bool = False) -> tuple[ExpressionState, Result]:
state, internal_prob = self.apply(state)
- return internal_prob.solve(verbose)
+ return internal_prob.solve(state, verbose)
@abstractmethod
def apply(self, state: ExpressionState) -> tuple[ExpressionState, InternalSOSProblem]:
@@ -139,10 +137,12 @@ class PutinarPSatz(Canonicalization):
"Degree of domain polynomial is higher than constraint polynomial!"
# Degree of multiplier must be even
- if d % 2 != 0:
+ if d.read(state).scalar().constant() % 2 != 0:
d += 1
- x = poly.v_stack((1,) + prob.polynomial_variables)
+ x = poly.v_stack((1,) + tuple(
+ init_variable_expr(v, state.get_shape(v))
+ for v in prob.polynomial_variables))
# FIXME: need to check that there is not a \gamma_{i} already!
# TODO: deterministically generate unique names
diff --git a/sumofsquares/problems.py b/sumofsquares/problems.py
index 14618fe..2ba9ee4 100644
--- a/sumofsquares/problems.py
+++ b/sumofsquares/problems.py
@@ -126,8 +126,8 @@ class ConicResult(Result):
# Unwrap the expression
symbol = var.underlying.symbol
- if var not in self.values:
- raise KeyError(f"There is no result for the variable {var}. "
+ if symbol not in self.values:
+ raise KeyError(f"There is no result for the variable {symbol}. "
f"Was the problem successfully solved?")
return self.values[symbol]
@@ -415,4 +415,4 @@ class InternalSOSProblem(Problem):
@override
def solve(self, state: ExpressionState, verbose: bool = False) -> tuple[ExpressionState, Result]:
- return state, self.to_conic_problem(state, verbose).solve(verbose)
+ return state, self.to_conic_problem(state=state, verbose=verbose).solve(verbose)