aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/index.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/mdpoly/index.py b/mdpoly/index.py
index 5ff28b6..5211245 100644
--- a/mdpoly/index.py
+++ b/mdpoly/index.py
@@ -160,8 +160,13 @@ class PolyIndex(tuple[PolyVarIndex]):
indices = ", ".join(map(repr, self))
return f"{name}({indices})"
- def __contains__(self, var_idx):
- return any(map(lambda e: e.var_idx == var_idx, self))
+ # FIXME: may need to convert to method to avoid breaking expected behaviour
+ def __contains__(self, index: PolyVarIndex | int) -> bool:
+ if isinstance(index, PolyVarIndex):
+ return tuple.__contains__(self, index)
+
+ if isinstance(index, int):
+ return any(map(lambda e: e.var_idx == index, self))
# -- Helper methods ---