summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-01-05 17:13:53 +0100
committerNao Pross <np@0hm.ch>2024-01-05 17:13:53 +0100
commit1b75a37d65102558915cc78edf63310dab5ceea1 (patch)
tree5ac7627595bebd10c5c04c7ddb5f741046ecdecb
parentTypos (diff)
downloadact4e-mcdp-1b75a37d65102558915cc78edf63310dab5ceea1.tar.gz
act4e-mcdp-1b75a37d65102558915cc78edf63310dab5ceea1.zip
Fix intersections
-rw-r--r--src/act4e_mcdp_solution/solver_dp.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/src/act4e_mcdp_solution/solver_dp.py b/src/act4e_mcdp_solution/solver_dp.py
index b4b0461..2798d47 100644
--- a/src/act4e_mcdp_solution/solver_dp.py
+++ b/src/act4e_mcdp_solution/solver_dp.py
@@ -710,9 +710,11 @@ class DPSolver(DPSolverInterface):
return us
def intersect_pair(u1: UpperSet[RT], u2: UpperSet[RT]) -> UpperSet[RT]:
- # always get the bigger of the two minimals
- minimals: List[RT] = [m for m in u1.minimals if m in u2.minimals]
- return UpperSet.from_points(poset, minimals)
+ def poset_max(m1: RT, m2: RT) -> RT:
+ return m1 if poset.geq(m1, m2) else m2
+
+ minimals: List[RT] = starmap(poset_max, zip(u1.minimals, u2.minimals))
+ return UpperSet(minimals=list(minimals))
return reduce(intersect_pair, us)
@@ -740,14 +742,13 @@ class DPSolver(DPSolverInterface):
# Return antichain that is fixed point of phi
def kleene_ascent(phi) -> List[R1]:
# initialize with bottoms
- chain: List[RT] = dp.dp.R.global_minima().minimals
- prev_chain = None
+ chain: List[RT] = list(dp.dp.R.global_minima().minimals)
+ prev_chain = []
# Ascent procedure
while chain != prev_chain:
prev_chain = chain
chain = phi(chain)
- # assert chain, "No chain!"
# Take R1 out of RT
first = lambda r: r[0]
@@ -772,8 +773,11 @@ class DPSolver(DPSolverInterface):
return ls
def intersect_pair(l1: LowerSet[FT], l2: LowerSet[FT]) -> LowerSet[FT]:
- maximals: List[FT] = [m for m in l1.maximals if m in l2.maximals]
- return LowerSet.from_points(poset, maximals)
+ def poset_min(m1: FT, m2: FT) -> FT:
+ return m1 if poset.leq(m1, m2) else m2
+
+ maximals: List[FT] = starmap(poset_min, zip(l1.maximals, l2.maximals))
+ return LowerSet(maximals=list(maximals))
return reduce(intersect_pair, ls)
@@ -795,12 +799,11 @@ class DPSolver(DPSolverInterface):
def kleene_ascent(phi) -> List[F1]:
chain: List[FT] = dp.dp.F.global_minima().minimals
- prev_chain = None
+ prev_chain = []
while chain != prev_chain:
prev_chain = chain
chain = phi(chain)
- # assert chain, "No chain!"
first = lambda f: f[0]
chain: List[F1] = list(map(first, chain))