diff options
-rw-r--r-- | mdpoly/algebra.py | 13 | ||||
-rw-r--r-- | mdpoly/index.py | 13 |
2 files changed, 19 insertions, 7 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 1fbabda..b45f0e0 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -204,6 +204,11 @@ class PolyRingExpr(Expr): "polynomial ring. Only constants and parameters are allowed.") return PolyExp(self, other) + # -- Other mathematical operations --- + + def diff(self, wrt: PolyVar) -> PolyPartialDiff: + return PolyPartialDiff(self, wrt) + @dataclassabc(frozen=True, repr=False) class PolyVar(Var, PolyRingExpr): @@ -360,8 +365,12 @@ class PolyPartialDiff(UnaryOp, PolyRingExpr): wrt = state.index(self.wrt) for term in lrepr.terms(entry): - if newterm := PolyIndex.differentiate(term, wrt): - r.set(entry, newterm, lrepr.at(entry, term)) + # sadly, walrus operator does not support unpacking + # https://bugs.python.org/issue43143 + # if ((newterm, fac) := PolyIndex.differentiate(term, wrt)): + if result := PolyIndex.differentiate(term, wrt): + newterm, fac = result + r.set(entry, newterm, fac * lrepr.at(entry, term)) return r, state diff --git a/mdpoly/index.py b/mdpoly/index.py index c085783..5ff28b6 100644 --- a/mdpoly/index.py +++ b/mdpoly/index.py @@ -233,12 +233,15 @@ class PolyIndex(tuple[PolyVarIndex]): return cls.sort(cls.from_dict(result)) @classmethod - def differentiate(cls, index: Self, wrt: int) -> Optional[Self]: + def differentiate(cls, index: Self, wrt: int) -> Optional[tuple[Self, Number]]: """ Compute the index after differentiation For example if the index is :math:`xy^2` and ``wrt`` is the index of - :math:`y` this function returns the index of :math:`xy`. In particular - this function takes care of the edge cases of differentiating a + :math:`y` this function returns the index of :math:`xy`. The function + also returns the exponent before it was differentiated, hence in the + example 2. + + This function takes care of the edge cases of differentiating a constant and differentiating linear terms. Specifically, if the index is a constant ``None`` is returned. """ @@ -257,8 +260,8 @@ class PolyIndex(tuple[PolyVarIndex]): # Check if is linear term if isclose(with_wrt_var.power, 1.): - return cls.sort(tuple(others) + cls.constant()) + return cls.sort(others), 1 # Decrease exponent new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1)) - return cls.sort(tuple(others) + (new_idx,)) + return cls.sort(tuple(others) + (new_idx,)), with_wrt_var.power |