From 0c10239db7b4cd83ec523953cba0ff7fd0057c2c Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Wed, 6 Mar 2024 18:42:23 +0100 Subject: Fix missing case for PolyIndex.differentiate --- mdpoly/index.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mdpoly/index.py b/mdpoly/index.py index 2837a46..02318a4 100644 --- a/mdpoly/index.py +++ b/mdpoly/index.py @@ -148,6 +148,9 @@ class PolyIndex(tuple[PolyVarIndex]): indices = ", ".join(map(repr, self)) return f"{name}({indices})" + def __contains__(self, var_idx): + return any(map(lambda e: e.var_idx == var_idx, self)) + @classmethod def from_dict(cls, d) -> Self: """ Construct an index froma dictionary, where the keys are the @@ -209,6 +212,9 @@ class PolyIndex(tuple[PolyVarIndex]): if cls.is_constant(index): return None + if wrt not in index: + return None + def is_wrt_var(idx: PolyVarIndex) -> bool: return idx.var_idx == wrt -- cgit v1.2.1