# -- In the sum of squares library ------------------------------------------------------- from abc import ABC, abstractmethod from enum import Enum, auto from typing import Callable, Self from dataclasses import dataclass, field from itertools import cycle from functools import wraps try: from typing import override except ImportError: from typing_extensions import override class Solver(Enum): """ Enum to select a solver """ CVXOPT = auto() @dataclass(frozen=True) class OptResult(ABC): """ Generic result from optimization problem """ success: bool # last problem was solved successfully class Problem(ABC): """ Optimization Problem. """ @abstractmethod def solve(self, solver: Solver) -> OptResult: """ Solve the optimization problem """ class SOSProblem(Problem): @override def solve(self, solver: Solver) -> OptResult: raise NotImplementedError Stage = Callable[[OptResult], Problem] SolvableStage = Callable[[OptResult, Solver], OptResult] @dataclass class HaltingPredicate: """ This wrapper is sadly necessary because otherwise we can't use `isinstance(stage, HaltingPredicate)` in the class below. """ check: Callable[[OptResult], bool] def __str__(self): return self.check.__name__ def __call__(self, res: OptResult): return self.check(res) @dataclass class MultiStageProblem(Problem): """ Pipeline for multi-state optimization problems. The pipeline is made of stages and halting predicates. If there are no halting predicates, the pipeline runs the stages by passing the result of each stage to the next. If there are halting predicates, the pipeline is repeated until one of the halting predicates tells it to stop. """ initial: OptResult solver: Solver = Solver.CVXOPT stages: list[Stage | HaltingPredicate] = field(default_factory=list) iterations: int = 0 # Magic methods def __str__(self): i, lines = 0, ["Multi-Stage Problem:"] for stage in self.stages: if isinstance(stage, HaltingPredicate): lines.append(f" halt? {stage}") else: lines.append(f" {i:02d} stage {stage.__name__}") i += 1 return "\n".join(lines) # Problem behaviour @override def solve(self) -> OptResult: """ Solve the multistage problem """ if HaltingPredicate in map(type, self.stages): return self._solve_repeating() return self._solve_once() def _solve_once(self) -> OptResult: self.iterations, result = 0, self.initial for stage in self.stages: result = stage(result, self.solver) self.iterations += 1 return result def _solve_repeating(self) -> OptResult: self.iterations, result = 0, self.initial for stage in cycle(self.stages): if isinstance(stage, HaltingPredicate): if stage(result): break else: result = stage(result, self.solver) return result # Wrappers @staticmethod def stage(fn: Stage) -> SolvableStage: """ Make a stage / step for the pipeline. """ @wraps(fn) def wrapper(res: OptResult, solver: Solver) -> OptResult: return fn(res).solve(solver) return wrapper @staticmethod def halt(fn: Callable[[OptResult], bool]) -> HaltingPredicate: return HaltingPredicate(fn) # Pipeline construction (plumbing) def and_then(self, fn: SolvableStage) -> Self: """ Add a stage to the pipeline that runs only if the previous stage completed with success. """ @wraps(fn) def wrapper(res: OptResult, solver: Solver) -> OptResult: if not res.success: return res # do nothing return fn(res, solver) self.stages.append(fn) return self def or_else(self, fn: SolvableStage) -> Self: """ Add a stage to the pipeline, that runs only if the previous stage failed. """ @wraps(fn) def wrapper(res: OptResult, solver: Solver) -> OptResult: if res.success: return res # do nothing return fn(res, solver) self.stages.append(wrapper) return self def stop_if(self, predicate: HaltingPredicate) -> Self: """ Add a predicate to stop the pipeline. """ self.stages.append(predicate) return self # -- In the user script ----------------------------------------------------------------- # from sumofsquares import MultiStageProblem # Define stages @MultiStageProblem.stage def solve_controller(result: OptResult) -> Problem: return SOSProblem() @MultiStageProblem.stage def solve_cbf_clf(result: OptResult) -> Problem: return SOSProblem() @MultiStageProblem.halt def good_op_region(result: OptResult) -> bool: return abs(result.roi) < 1e-3 # Initialize multi-stage problem and define order power_converter_prob = ( MultiStageProblem(initial=OptResult) # Pass initialization in constructor .and_then(solve_controller) .stop_if(good_op_region) .and_then(solve_cbf_clf) ) print(power_converter_prob) # res: OptResult = power_converter_prob.solve()