summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-01-03 15:53:29 +0100
committerNao Pross <np@0hm.ch>2024-01-03 16:48:30 +0100
commit5706d2c415c7a387470093ae38c0a9fcb49feb75 (patch)
tree4693e19d4e3c150b1fc368cce62c840d4706bbbd
parentPass everything in first two parts except for DPLoop (diff)
downloadact4e-mcdp-5706d2c415c7a387470093ae38c0a9fcb49feb75.tar.gz
act4e-mcdp-5706d2c415c7a387470093ae38c0a9fcb49feb75.zip
Start implementing DPLoop
-rw-r--r--src/act4e_mcdp_solution/solver_dp.py84
1 files changed, 79 insertions, 5 deletions
diff --git a/src/act4e_mcdp_solution/solver_dp.py b/src/act4e_mcdp_solution/solver_dp.py
index a95451f..4a573b3 100644
--- a/src/act4e_mcdp_solution/solver_dp.py
+++ b/src/act4e_mcdp_solution/solver_dp.py
@@ -1,8 +1,10 @@
-from typing import Optional, TypeVar
+import typing
+from typing import Optional, TypeVar, Tuple, List
from act4e_mcdp import * # type: ignore
from decimal import Decimal
from functools import reduce
-from typing import Tuple
+from itertools import starmap
+
import itertools
@@ -12,6 +14,7 @@ __all__ = [
X = TypeVar("X")
+M = TypeVar("M")
FT = TypeVar("FT")
RT = TypeVar("RT")
R1 = TypeVar("R1")
@@ -624,7 +627,7 @@ class DPSolver(DPSolverInterface):
# Similar to a case above
f = query.functionality
- min_r = r / dp.vu.value
+ min_r = f / dp.vu.value
us = dp.R.largest_upperset_above(min_r)
return Interval.degenerate(us)
@@ -691,8 +694,64 @@ class DPSolver(DPSolverInterface):
# As in the book, the intermediate goal is to define a function f such that
# the solution is the least fixed point of f.
-
- raise NotImplementedError
+
+ def uppersets_intersection(poset: Poset[RT], us: List[UpperSet[RT]]) -> UpperSet[RT]:
+ if len(us) < 2:
+ return us
+
+ def poset_max(m1: RT, m2: RT) -> RT:
+ return m1 if poset.geq(m1, m2) else m2
+
+ def intersect_pair(u1: UpperSet[RT], u2: UpperSet[RT]) -> UpperSet[RT]:
+ # always get the bigger of the two minimals
+ minimals: List[RT] = list(starmap(poset_max, zip(u1.minimals, u2.minimals)))
+ return UpperSet.from_points(poset, minimals)
+
+ return reduce(intersect_pair, us[1:], us[0])
+
+ # Map antichain to antichain
+ def phi(chain: List[RT]) -> List[RT]:
+ uppersets: List[UpperSet[RT]] = []
+ for r in chain:
+ # Query for internal system that is looped
+ # with FT = F1 \otimes M and RT = R1 \otimes M
+ inner_f: FT = (query.functionality,) + r[1:]
+ inner_query: FixFunMinResQuery[FT] = FixFunMinResQuery(functionality=inner_f)
+
+ # Take the pessimistic solution
+ min_rs: Interval[UpperSet[RT]] = self.solve_dp_FixFunMinRes(dp.dp, inner_query)
+ min_r: UpperSet[RT] = min_rs.pessimistic
+
+ # compute intersection of \uparrow r with h_d(f, r)
+ up_r = UpperSet.principal(r)
+ inters = uppersets_intersection(dp.dp.R, [min_r, up_r])
+ uppersets.append(inters)
+
+ # Return antichain of upperset union (minimum)
+ union: UpperSet[RT] = UpperSet.union(uppersets, dp.dp.R)
+ return union.minimals
+
+ # Return antichain that is fixed point of phi
+ def kleene_ascent(phi, maxiter=100) -> List[R1]:
+ # initialize with bottoms,
+ chain: List[RT] = dp.dp.R.global_maxima().maximals
+ prev_chain = None
+
+ # Ascent procedure
+ while chain != prev_chain:
+ prev_chain = chain
+ chain = phi(chain)
+
+ # Take R1 out of RT
+ first = lambda r: r[0]
+ chain: List[R1] = list(map(first, chain))
+ return chain
+
+ # Get chain that is a fixed point of phi and return upper set
+ chain: List[R1] = kleene_ascent(phi)
+ min_r_loop: UpperSet[R1] = UpperSet.from_points(dp.R, chain)
+
+ return Interval.degenerate(min_r_loop)
def solve_dp_FixResMaxFun_DPLoop2(
self, dp: DPLoop2[F1, R1, object], query: FixResMaxFunQuery[R1]
@@ -700,4 +759,19 @@ class DPSolver(DPSolverInterface):
# Note: this is an advanced exercise.
# Hint: same as above, but go the other way...
+
raise NotImplementedError
+
+ def lowersets_intersection(poset: Poset[FT], ls: List[LowerSet[FT]]) -> LowerSet[FT]:
+ if len(ls) < 2:
+ return ls
+
+ def intersect_pair(l1: LowerSet[FT], l2: LowerSet[FT]) -> LowerSet[LT]:
+ elements = set()
+ for m1 in l1.maximals:
+ if m1 in l2.maximals:
+ elements.add(m1)
+
+ maximals = list(poset.maximals(elements))
+
+