aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/algebra.py13
-rw-r--r--mdpoly/index.py13
2 files changed, 19 insertions, 7 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 1fbabda..b45f0e0 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -204,6 +204,11 @@ class PolyRingExpr(Expr):
"polynomial ring. Only constants and parameters are allowed.")
return PolyExp(self, other)
+ # -- Other mathematical operations ---
+
+ def diff(self, wrt: PolyVar) -> PolyPartialDiff:
+ return PolyPartialDiff(self, wrt)
+
@dataclassabc(frozen=True, repr=False)
class PolyVar(Var, PolyRingExpr):
@@ -360,8 +365,12 @@ class PolyPartialDiff(UnaryOp, PolyRingExpr):
wrt = state.index(self.wrt)
for term in lrepr.terms(entry):
- if newterm := PolyIndex.differentiate(term, wrt):
- r.set(entry, newterm, lrepr.at(entry, term))
+ # sadly, walrus operator does not support unpacking
+ # https://bugs.python.org/issue43143
+ # if ((newterm, fac) := PolyIndex.differentiate(term, wrt)):
+ if result := PolyIndex.differentiate(term, wrt):
+ newterm, fac = result
+ r.set(entry, newterm, fac * lrepr.at(entry, term))
return r, state
diff --git a/mdpoly/index.py b/mdpoly/index.py
index c085783..5ff28b6 100644
--- a/mdpoly/index.py
+++ b/mdpoly/index.py
@@ -233,12 +233,15 @@ class PolyIndex(tuple[PolyVarIndex]):
return cls.sort(cls.from_dict(result))
@classmethod
- def differentiate(cls, index: Self, wrt: int) -> Optional[Self]:
+ def differentiate(cls, index: Self, wrt: int) -> Optional[tuple[Self, Number]]:
""" Compute the index after differentiation
For example if the index is :math:`xy^2` and ``wrt`` is the index of
- :math:`y` this function returns the index of :math:`xy`. In particular
- this function takes care of the edge cases of differentiating a
+ :math:`y` this function returns the index of :math:`xy`. The function
+ also returns the exponent before it was differentiated, hence in the
+ example 2.
+
+ This function takes care of the edge cases of differentiating a
constant and differentiating linear terms. Specifically, if the index
is a constant ``None`` is returned.
"""
@@ -257,8 +260,8 @@ class PolyIndex(tuple[PolyVarIndex]):
# Check if is linear term
if isclose(with_wrt_var.power, 1.):
- return cls.sort(tuple(others) + cls.constant())
+ return cls.sort(others), 1
# Decrease exponent
new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1))
- return cls.sort(tuple(others) + (new_idx,))
+ return cls.sort(tuple(others) + (new_idx,)), with_wrt_var.power