aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-03 14:31:41 +0100
committerNao Pross <np@0hm.ch>2024-03-03 14:31:41 +0100
commit2d0570099be4ae94f4e04777613e756081c311df (patch)
treee29e07bafad9c0d616bf56a8a03356ba6d9401e7
parentDisallow indexing of non-variables in state (diff)
downloadmdpoly-2d0570099be4ae94f4e04777613e756081c311df.tar.gz
mdpoly-2d0570099be4ae94f4e04777613e756081c311df.zip
Implement partial derivative
-rw-r--r--mdpoly/algebra.py29
-rw-r--r--mdpoly/types.py34
-rw-r--r--mdpoly/util.py12
3 files changed, 69 insertions, 6 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 12c9d10..6d245c6 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -1,7 +1,7 @@
""" Algebraic Structures for Expressions """
from .abc import Expr, Repr
-from .leaves import Nothing, Const, Param
+from .leaves import Nothing, Const, Param, Var
from .errors import AlgebraicError
from .state import State
from .types import MatrixIndex, PolyIndex
@@ -35,6 +35,7 @@ def unary_operator(cls: Expr) -> Expr:
def __init_unary__(self, left, *args, **kwargs):
__init_cls__(self, *args, **kwargs)
self.left, self.right = left, Nothing
+ self.inner = left
cls.__init__ = __init_unary__
return cls
@@ -110,7 +111,7 @@ class ReducibleExpr(HasRepr, Protocol):
"""
def to_repr(self, repr_type: type, state: State) -> Repr:
- """ See :py:meth:`HasRepr.to_repr` """
+ """ See :py:meth:`mdpoly.representations.HasRepr.to_repr` """
return self.reduce().to_repr(repr_type, state)
@abstractmethod
@@ -262,10 +263,32 @@ class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra):
@unary_operator
+class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
+ """ Partial differentiation of scalar polynomials """
+ def __init__(self, with_respect_to: Var):
+ self.wrt = with_respect_to
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
+ r = repr_type()
+ lrepr, state = self.left.to_repr(repr_type, state)
+
+ entry = MatrixIndex.scalar
+ wrt = state.index(self.wrt)
+
+ for term in lrepr.terms(entry):
+ if newterm := PolyIndex.differentiate(term, wrt):
+ r.set(entry, newterm, lrepr.at(entry, term))
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"(∂_{self.wrt} {self.inner})"
+
+@unary_operator
class Diff(Expr, PolyRingAlgebra):
""" Total differentiation. """
-
# ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓
# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
diff --git a/mdpoly/types.py b/mdpoly/types.py
index 1132260..1605aa8 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -2,9 +2,10 @@ from __future__ import annotations
from .constants import NUMERICS_EPS
from .errors import InvalidShape
+from .util import partition
from itertools import filterfalse
-from typing import Self, NamedTuple, Iterable, TYPE_CHECKING
+from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .state import State
@@ -161,8 +162,7 @@ class PolyIndex(tuple[PolyVarIndex]):
@classmethod
def product(cls, left: Self, right: Self) -> Self:
- """
- Compute index of product.
+ """ Compute index of product.
For example if left is the index of :math:`xy` and right is the index
of :math:`y^2` this functions returns the index of :math:`xy^3`.
@@ -181,3 +181,31 @@ class PolyIndex(tuple[PolyVarIndex]):
result[idx] += power
return cls.sort(cls.from_dict(result))
+
+ @classmethod
+ def differentiate(cls, index: Self, wrt: int) -> Optional[Self]:
+ """ Compute the index of differentiation
+
+ For example if the index is :math:`xy^2` and ``wrt`` is the index of
+ :math:`y` this function returns the index of :math:`xy`. In particular
+ this function takes care of the edge cases of differentiating a
+ constant and differentiating linear terms. Specifically, if the index
+ is a constant ``None`` is returned.
+ """
+
+ if cls.is_constant(index):
+ return None
+
+ def is_wrt_var(idx: PolyVarIndex) -> bool:
+ return idx.var_idx == wrt
+
+ others, with_wrt_var = partition(is_wrt_var, index)
+ with_wrt_var, *_ = with_wrt_var # Take first, as should be only one
+
+ # Check if is linear term
+ if (with_wrt_var.power - 1) < NUMERICS_EPS:
+ return cls.sort(tuple(others) + cls.constant)
+
+ # Decrease exponent
+ new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1))
+ return cls.sort(tuple(others) + (new_idx,))
diff --git a/mdpoly/util.py b/mdpoly/util.py
new file mode 100644
index 0000000..f1d5de2
--- /dev/null
+++ b/mdpoly/util.py
@@ -0,0 +1,12 @@
+from itertools import tee, filterfalse
+
+def partition(pred, iterable):
+ """Partition entries into false entries and true entries.
+
+ If *pred* is slow, consider wrapping it with functools.lru_cache().
+ """
+ # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
+ t1, t2 = tee(iterable)
+ return filterfalse(pred, t1), filter(pred, t2)
+
+