diff options
Diffstat (limited to 'pipeline.py')
-rw-r--r-- | pipeline.py | 180 |
1 files changed, 180 insertions, 0 deletions
diff --git a/pipeline.py b/pipeline.py new file mode 100644 index 0000000..eda598d --- /dev/null +++ b/pipeline.py @@ -0,0 +1,180 @@ +# -- In the sum of squares library ------------------------------------------------------- +from abc import ABC, abstractmethod +from enum import Enum, auto +from typing import Callable, Self +from dataclasses import dataclass, field +from itertools import cycle +from functools import wraps + +try: + from typing import override +except ImportError: + from typing_extensions import override + + +class Solver(Enum): + """ Enum to select a solver """ + CVXOPT = auto() + + +@dataclass(frozen=True) +class OptResult(ABC): + """ Generic result from optimization problem """ + success: bool # last problem was solved successfully + + +class Problem(ABC): + """ Optimization Problem. """ + + @abstractmethod + def solve(self, solver: Solver) -> OptResult: + """ Solve the optimization problem """ + + +class SOSProblem(Problem): + @override + def solve(self, solver: Solver) -> OptResult: + raise NotImplementedError + + +Stage = Callable[[OptResult], Problem] +SolvableStage = Callable[[OptResult, Solver], OptResult] + +@dataclass +class HaltingPredicate: + """ This wrapper is sadly necessary because otherwise we can't use + `isinstance(stage, HaltingPredicate)` in the class below. """ + check: Callable[[OptResult], bool] + + def __str__(self): + return self.check.__name__ + + def __call__(self, res: OptResult): + return self.check(res) + + +@dataclass +class MultiStageProblem(Problem): + """ Pipeline for multi-state optimization problems. + + The pipeline is made of stages and halting predicates. If there are no halting + predicates, the pipeline runs the stages by passing the result of each stage to the + next. If there are halting predicates, the pipeline is repeated until one of the + halting predicates tells it to stop. + """ + initial: OptResult + solver: Solver = Solver.CVXOPT + stages: list[Stage | HaltingPredicate] = field(default_factory=list) + iterations: int = 0 + + # Magic methods + + def __str__(self): + i, lines = 0, ["Multi-Stage Problem:"] + for stage in self.stages: + if isinstance(stage, HaltingPredicate): + lines.append(f" halt? {stage}") + else: + lines.append(f" {i:02d} stage {stage.__name__}") + i += 1 + return "\n".join(lines) + + # Problem behaviour + + @override + def solve(self) -> OptResult: + """ Solve the multistage problem """ + if HaltingPredicate in map(type, self.stages): + return self._solve_repeating() + return self._solve_once() + + def _solve_once(self) -> OptResult: + self.iterations, result = 0, self.initial + for stage in self.stages: + result = stage(result, self.solver) + self.iterations += 1 + return result + + def _solve_repeating(self) -> OptResult: + self.iterations, result = 0, self.initial + for stage in cycle(self.stages): + if isinstance(stage, HaltingPredicate): + if stage(result): + break + else: + result = stage(result, self.solver) + return result + + # Wrappers + + @staticmethod + def stage(fn: Stage) -> SolvableStage: + """ Make a stage / step for the pipeline. """ + @wraps(fn) + def wrapper(res: OptResult, solver: Solver) -> OptResult: + return fn(res).solve(solver) + + return wrapper + + @staticmethod + def halt(fn: Callable[[OptResult], bool]) -> HaltingPredicate: + return HaltingPredicate(fn) + + # Pipeline construction (plumbing) + + def and_then(self, fn: SolvableStage) -> Self: + """ Add a stage to the pipeline that runs only if the previous stage + completed with success. """ + @wraps(fn) + def wrapper(res: OptResult, solver: Solver) -> OptResult: + if not res.success: + return res # do nothing + return fn(res, solver) + + self.stages.append(fn) + return self + + def or_else(self, fn: SolvableStage) -> Self: + """ Add a stage to the pipeline, that runs only if the previous stage + failed. """ + @wraps(fn) + def wrapper(res: OptResult, solver: Solver) -> OptResult: + if res.success: + return res # do nothing + return fn(res, solver) + + self.stages.append(wrapper) + return self + + def stop_if(self, predicate: HaltingPredicate) -> Self: + """ Add a predicate to stop the pipeline. """ + self.stages.append(predicate) + return self + +# -- In the user script ----------------------------------------------------------------- +# from sumofsquares import MultiStageProblem + +# Define stages +@MultiStageProblem.stage +def solve_controller(result: OptResult) -> Problem: + return SOSProblem() + +@MultiStageProblem.stage +def solve_cbf_clf(result: OptResult) -> Problem: + return SOSProblem() + +@MultiStageProblem.halt +def good_op_region(result: OptResult) -> bool: + return abs(result.roi) < 1e-3 + + +# Initialize multi-stage problem and define order +power_converter_prob = ( + MultiStageProblem(initial=OptResult) # Pass initialization in constructor + .and_then(solve_controller) + .stop_if(good_op_region) + .and_then(solve_cbf_clf) +) + +print(power_converter_prob) +# res: OptResult = power_converter_prob.solve() |