aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/index.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/mdpoly/index.py b/mdpoly/index.py
index 2837a46..02318a4 100644
--- a/mdpoly/index.py
+++ b/mdpoly/index.py
@@ -148,6 +148,9 @@ class PolyIndex(tuple[PolyVarIndex]):
indices = ", ".join(map(repr, self))
return f"{name}({indices})"
+ def __contains__(self, var_idx):
+ return any(map(lambda e: e.var_idx == var_idx, self))
+
@classmethod
def from_dict(cls, d) -> Self:
""" Construct an index froma dictionary, where the keys are the
@@ -209,6 +212,9 @@ class PolyIndex(tuple[PolyVarIndex]):
if cls.is_constant(index):
return None
+ if wrt not in index:
+ return None
+
def is_wrt_var(idx: PolyVarIndex) -> bool:
return idx.var_idx == wrt