diff options
author | Nao Pross <np@0hm.ch> | 2024-01-03 15:53:29 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-01-03 16:48:30 +0100 |
commit | 5706d2c415c7a387470093ae38c0a9fcb49feb75 (patch) | |
tree | 4693e19d4e3c150b1fc368cce62c840d4706bbbd | |
parent | Pass everything in first two parts except for DPLoop (diff) | |
download | act4e-mcdp-5706d2c415c7a387470093ae38c0a9fcb49feb75.tar.gz act4e-mcdp-5706d2c415c7a387470093ae38c0a9fcb49feb75.zip |
Start implementing DPLoop
-rw-r--r-- | src/act4e_mcdp_solution/solver_dp.py | 84 |
1 files changed, 79 insertions, 5 deletions
diff --git a/src/act4e_mcdp_solution/solver_dp.py b/src/act4e_mcdp_solution/solver_dp.py index a95451f..4a573b3 100644 --- a/src/act4e_mcdp_solution/solver_dp.py +++ b/src/act4e_mcdp_solution/solver_dp.py @@ -1,8 +1,10 @@ -from typing import Optional, TypeVar +import typing +from typing import Optional, TypeVar, Tuple, List from act4e_mcdp import * # type: ignore from decimal import Decimal from functools import reduce -from typing import Tuple +from itertools import starmap + import itertools @@ -12,6 +14,7 @@ __all__ = [ X = TypeVar("X") +M = TypeVar("M") FT = TypeVar("FT") RT = TypeVar("RT") R1 = TypeVar("R1") @@ -624,7 +627,7 @@ class DPSolver(DPSolverInterface): # Similar to a case above f = query.functionality - min_r = r / dp.vu.value + min_r = f / dp.vu.value us = dp.R.largest_upperset_above(min_r) return Interval.degenerate(us) @@ -691,8 +694,64 @@ class DPSolver(DPSolverInterface): # As in the book, the intermediate goal is to define a function f such that # the solution is the least fixed point of f. - - raise NotImplementedError + + def uppersets_intersection(poset: Poset[RT], us: List[UpperSet[RT]]) -> UpperSet[RT]: + if len(us) < 2: + return us + + def poset_max(m1: RT, m2: RT) -> RT: + return m1 if poset.geq(m1, m2) else m2 + + def intersect_pair(u1: UpperSet[RT], u2: UpperSet[RT]) -> UpperSet[RT]: + # always get the bigger of the two minimals + minimals: List[RT] = list(starmap(poset_max, zip(u1.minimals, u2.minimals))) + return UpperSet.from_points(poset, minimals) + + return reduce(intersect_pair, us[1:], us[0]) + + # Map antichain to antichain + def phi(chain: List[RT]) -> List[RT]: + uppersets: List[UpperSet[RT]] = [] + for r in chain: + # Query for internal system that is looped + # with FT = F1 \otimes M and RT = R1 \otimes M + inner_f: FT = (query.functionality,) + r[1:] + inner_query: FixFunMinResQuery[FT] = FixFunMinResQuery(functionality=inner_f) + + # Take the pessimistic solution + min_rs: Interval[UpperSet[RT]] = self.solve_dp_FixFunMinRes(dp.dp, inner_query) + min_r: UpperSet[RT] = min_rs.pessimistic + + # compute intersection of \uparrow r with h_d(f, r) + up_r = UpperSet.principal(r) + inters = uppersets_intersection(dp.dp.R, [min_r, up_r]) + uppersets.append(inters) + + # Return antichain of upperset union (minimum) + union: UpperSet[RT] = UpperSet.union(uppersets, dp.dp.R) + return union.minimals + + # Return antichain that is fixed point of phi + def kleene_ascent(phi, maxiter=100) -> List[R1]: + # initialize with bottoms, + chain: List[RT] = dp.dp.R.global_maxima().maximals + prev_chain = None + + # Ascent procedure + while chain != prev_chain: + prev_chain = chain + chain = phi(chain) + + # Take R1 out of RT + first = lambda r: r[0] + chain: List[R1] = list(map(first, chain)) + return chain + + # Get chain that is a fixed point of phi and return upper set + chain: List[R1] = kleene_ascent(phi) + min_r_loop: UpperSet[R1] = UpperSet.from_points(dp.R, chain) + + return Interval.degenerate(min_r_loop) def solve_dp_FixResMaxFun_DPLoop2( self, dp: DPLoop2[F1, R1, object], query: FixResMaxFunQuery[R1] @@ -700,4 +759,19 @@ class DPSolver(DPSolverInterface): # Note: this is an advanced exercise. # Hint: same as above, but go the other way... + raise NotImplementedError + + def lowersets_intersection(poset: Poset[FT], ls: List[LowerSet[FT]]) -> LowerSet[FT]: + if len(ls) < 2: + return ls + + def intersect_pair(l1: LowerSet[FT], l2: LowerSet[FT]) -> LowerSet[LT]: + elements = set() + for m1 in l1.maximals: + if m1 in l2.maximals: + elements.add(m1) + + maximals = list(poset.maximals(elements)) + + |