summaryrefslogtreecommitdiffstats
path: root/pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'pipeline.py')
-rw-r--r--pipeline.py180
1 files changed, 180 insertions, 0 deletions
diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000..eda598d
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,180 @@
+# -- In the sum of squares library -------------------------------------------------------
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import Callable, Self
+from dataclasses import dataclass, field
+from itertools import cycle
+from functools import wraps
+
+try:
+ from typing import override
+except ImportError:
+ from typing_extensions import override
+
+
+class Solver(Enum):
+ """ Enum to select a solver """
+ CVXOPT = auto()
+
+
+@dataclass(frozen=True)
+class OptResult(ABC):
+ """ Generic result from optimization problem """
+ success: bool # last problem was solved successfully
+
+
+class Problem(ABC):
+ """ Optimization Problem. """
+
+ @abstractmethod
+ def solve(self, solver: Solver) -> OptResult:
+ """ Solve the optimization problem """
+
+
+class SOSProblem(Problem):
+ @override
+ def solve(self, solver: Solver) -> OptResult:
+ raise NotImplementedError
+
+
+Stage = Callable[[OptResult], Problem]
+SolvableStage = Callable[[OptResult, Solver], OptResult]
+
+@dataclass
+class HaltingPredicate:
+ """ This wrapper is sadly necessary because otherwise we can't use
+ `isinstance(stage, HaltingPredicate)` in the class below. """
+ check: Callable[[OptResult], bool]
+
+ def __str__(self):
+ return self.check.__name__
+
+ def __call__(self, res: OptResult):
+ return self.check(res)
+
+
+@dataclass
+class MultiStageProblem(Problem):
+ """ Pipeline for multi-state optimization problems.
+
+ The pipeline is made of stages and halting predicates. If there are no halting
+ predicates, the pipeline runs the stages by passing the result of each stage to the
+ next. If there are halting predicates, the pipeline is repeated until one of the
+ halting predicates tells it to stop.
+ """
+ initial: OptResult
+ solver: Solver = Solver.CVXOPT
+ stages: list[Stage | HaltingPredicate] = field(default_factory=list)
+ iterations: int = 0
+
+ # Magic methods
+
+ def __str__(self):
+ i, lines = 0, ["Multi-Stage Problem:"]
+ for stage in self.stages:
+ if isinstance(stage, HaltingPredicate):
+ lines.append(f" halt? {stage}")
+ else:
+ lines.append(f" {i:02d} stage {stage.__name__}")
+ i += 1
+ return "\n".join(lines)
+
+ # Problem behaviour
+
+ @override
+ def solve(self) -> OptResult:
+ """ Solve the multistage problem """
+ if HaltingPredicate in map(type, self.stages):
+ return self._solve_repeating()
+ return self._solve_once()
+
+ def _solve_once(self) -> OptResult:
+ self.iterations, result = 0, self.initial
+ for stage in self.stages:
+ result = stage(result, self.solver)
+ self.iterations += 1
+ return result
+
+ def _solve_repeating(self) -> OptResult:
+ self.iterations, result = 0, self.initial
+ for stage in cycle(self.stages):
+ if isinstance(stage, HaltingPredicate):
+ if stage(result):
+ break
+ else:
+ result = stage(result, self.solver)
+ return result
+
+ # Wrappers
+
+ @staticmethod
+ def stage(fn: Stage) -> SolvableStage:
+ """ Make a stage / step for the pipeline. """
+ @wraps(fn)
+ def wrapper(res: OptResult, solver: Solver) -> OptResult:
+ return fn(res).solve(solver)
+
+ return wrapper
+
+ @staticmethod
+ def halt(fn: Callable[[OptResult], bool]) -> HaltingPredicate:
+ return HaltingPredicate(fn)
+
+ # Pipeline construction (plumbing)
+
+ def and_then(self, fn: SolvableStage) -> Self:
+ """ Add a stage to the pipeline that runs only if the previous stage
+ completed with success. """
+ @wraps(fn)
+ def wrapper(res: OptResult, solver: Solver) -> OptResult:
+ if not res.success:
+ return res # do nothing
+ return fn(res, solver)
+
+ self.stages.append(fn)
+ return self
+
+ def or_else(self, fn: SolvableStage) -> Self:
+ """ Add a stage to the pipeline, that runs only if the previous stage
+ failed. """
+ @wraps(fn)
+ def wrapper(res: OptResult, solver: Solver) -> OptResult:
+ if res.success:
+ return res # do nothing
+ return fn(res, solver)
+
+ self.stages.append(wrapper)
+ return self
+
+ def stop_if(self, predicate: HaltingPredicate) -> Self:
+ """ Add a predicate to stop the pipeline. """
+ self.stages.append(predicate)
+ return self
+
+# -- In the user script -----------------------------------------------------------------
+# from sumofsquares import MultiStageProblem
+
+# Define stages
+@MultiStageProblem.stage
+def solve_controller(result: OptResult) -> Problem:
+ return SOSProblem()
+
+@MultiStageProblem.stage
+def solve_cbf_clf(result: OptResult) -> Problem:
+ return SOSProblem()
+
+@MultiStageProblem.halt
+def good_op_region(result: OptResult) -> bool:
+ return abs(result.roi) < 1e-3
+
+
+# Initialize multi-stage problem and define order
+power_converter_prob = (
+ MultiStageProblem(initial=OptResult) # Pass initialization in constructor
+ .and_then(solve_controller)
+ .stop_if(good_op_region)
+ .and_then(solve_cbf_clf)
+)
+
+print(power_converter_prob)
+# res: OptResult = power_converter_prob.solve()