aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-10 23:30:32 +0100
committerNao Pross <np@0hm.ch>2024-03-10 23:32:08 +0100
commit7c091b9ed2fcfb6407ed7ba4c28d2ba076e0f827 (patch)
tree1091e72567f1eb8be915bc47bb3ea6eeaf1a5dbb
parentFix PolyRingExpr.__neg__ (diff)
downloadmdpoly-7c091b9ed2fcfb6407ed7ba4c28d2ba076e0f827.tar.gz
mdpoly-7c091b9ed2fcfb6407ed7ba4c28d2ba076e0f827.zip
Fix bug in PolyPartialDiff and add PolyRingExpr.diff(wrt)
Forgot to premultiply with old exponent, d_x (x^2) = 2 * x, bug was that d_x (x^2) gives x, which is wrong
-rw-r--r--mdpoly/algebra.py13
-rw-r--r--mdpoly/index.py13
2 files changed, 19 insertions, 7 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 1fbabda..b45f0e0 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -204,6 +204,11 @@ class PolyRingExpr(Expr):
"polynomial ring. Only constants and parameters are allowed.")
return PolyExp(self, other)
+ # -- Other mathematical operations ---
+
+ def diff(self, wrt: PolyVar) -> PolyPartialDiff:
+ return PolyPartialDiff(self, wrt)
+
@dataclassabc(frozen=True, repr=False)
class PolyVar(Var, PolyRingExpr):
@@ -360,8 +365,12 @@ class PolyPartialDiff(UnaryOp, PolyRingExpr):
wrt = state.index(self.wrt)
for term in lrepr.terms(entry):
- if newterm := PolyIndex.differentiate(term, wrt):
- r.set(entry, newterm, lrepr.at(entry, term))
+ # sadly, walrus operator does not support unpacking
+ # https://bugs.python.org/issue43143
+ # if ((newterm, fac) := PolyIndex.differentiate(term, wrt)):
+ if result := PolyIndex.differentiate(term, wrt):
+ newterm, fac = result
+ r.set(entry, newterm, fac * lrepr.at(entry, term))
return r, state
diff --git a/mdpoly/index.py b/mdpoly/index.py
index c085783..5ff28b6 100644
--- a/mdpoly/index.py
+++ b/mdpoly/index.py
@@ -233,12 +233,15 @@ class PolyIndex(tuple[PolyVarIndex]):
return cls.sort(cls.from_dict(result))
@classmethod
- def differentiate(cls, index: Self, wrt: int) -> Optional[Self]:
+ def differentiate(cls, index: Self, wrt: int) -> Optional[tuple[Self, Number]]:
""" Compute the index after differentiation
For example if the index is :math:`xy^2` and ``wrt`` is the index of
- :math:`y` this function returns the index of :math:`xy`. In particular
- this function takes care of the edge cases of differentiating a
+ :math:`y` this function returns the index of :math:`xy`. The function
+ also returns the exponent before it was differentiated, hence in the
+ example 2.
+
+ This function takes care of the edge cases of differentiating a
constant and differentiating linear terms. Specifically, if the index
is a constant ``None`` is returned.
"""
@@ -257,8 +260,8 @@ class PolyIndex(tuple[PolyVarIndex]):
# Check if is linear term
if isclose(with_wrt_var.power, 1.):
- return cls.sort(tuple(others) + cls.constant())
+ return cls.sort(others), 1
# Decrease exponent
new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1))
- return cls.sort(tuple(others) + (new_idx,))
+ return cls.sort(tuple(others) + (new_idx,)), with_wrt_var.power