diff options
author | Nao Pross <np@0hm.ch> | 2024-01-05 17:13:53 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-01-05 17:13:53 +0100 |
commit | 1b75a37d65102558915cc78edf63310dab5ceea1 (patch) | |
tree | 5ac7627595bebd10c5c04c7ddb5f741046ecdecb | |
parent | Typos (diff) | |
download | act4e-mcdp-1b75a37d65102558915cc78edf63310dab5ceea1.tar.gz act4e-mcdp-1b75a37d65102558915cc78edf63310dab5ceea1.zip |
Fix intersections
-rw-r--r-- | src/act4e_mcdp_solution/solver_dp.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/src/act4e_mcdp_solution/solver_dp.py b/src/act4e_mcdp_solution/solver_dp.py index b4b0461..2798d47 100644 --- a/src/act4e_mcdp_solution/solver_dp.py +++ b/src/act4e_mcdp_solution/solver_dp.py @@ -710,9 +710,11 @@ class DPSolver(DPSolverInterface): return us def intersect_pair(u1: UpperSet[RT], u2: UpperSet[RT]) -> UpperSet[RT]: - # always get the bigger of the two minimals - minimals: List[RT] = [m for m in u1.minimals if m in u2.minimals] - return UpperSet.from_points(poset, minimals) + def poset_max(m1: RT, m2: RT) -> RT: + return m1 if poset.geq(m1, m2) else m2 + + minimals: List[RT] = starmap(poset_max, zip(u1.minimals, u2.minimals)) + return UpperSet(minimals=list(minimals)) return reduce(intersect_pair, us) @@ -740,14 +742,13 @@ class DPSolver(DPSolverInterface): # Return antichain that is fixed point of phi def kleene_ascent(phi) -> List[R1]: # initialize with bottoms - chain: List[RT] = dp.dp.R.global_minima().minimals - prev_chain = None + chain: List[RT] = list(dp.dp.R.global_minima().minimals) + prev_chain = [] # Ascent procedure while chain != prev_chain: prev_chain = chain chain = phi(chain) - # assert chain, "No chain!" # Take R1 out of RT first = lambda r: r[0] @@ -772,8 +773,11 @@ class DPSolver(DPSolverInterface): return ls def intersect_pair(l1: LowerSet[FT], l2: LowerSet[FT]) -> LowerSet[FT]: - maximals: List[FT] = [m for m in l1.maximals if m in l2.maximals] - return LowerSet.from_points(poset, maximals) + def poset_min(m1: FT, m2: FT) -> FT: + return m1 if poset.leq(m1, m2) else m2 + + maximals: List[FT] = starmap(poset_min, zip(l1.maximals, l2.maximals)) + return LowerSet(maximals=list(maximals)) return reduce(intersect_pair, ls) @@ -795,12 +799,11 @@ class DPSolver(DPSolverInterface): def kleene_ascent(phi) -> List[F1]: chain: List[FT] = dp.dp.F.global_minima().minimals - prev_chain = None + prev_chain = [] while chain != prev_chain: prev_chain = chain chain = phi(chain) - # assert chain, "No chain!" first = lambda f: f[0] chain: List[F1] = list(map(first, chain)) |