diff options
-rw-r--r-- | mdpoly/algebra.py | 29 | ||||
-rw-r--r-- | mdpoly/types.py | 34 | ||||
-rw-r--r-- | mdpoly/util.py | 12 |
3 files changed, 69 insertions, 6 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 12c9d10..6d245c6 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -1,7 +1,7 @@ """ Algebraic Structures for Expressions """ from .abc import Expr, Repr -from .leaves import Nothing, Const, Param +from .leaves import Nothing, Const, Param, Var from .errors import AlgebraicError from .state import State from .types import MatrixIndex, PolyIndex @@ -35,6 +35,7 @@ def unary_operator(cls: Expr) -> Expr: def __init_unary__(self, left, *args, **kwargs): __init_cls__(self, *args, **kwargs) self.left, self.right = left, Nothing + self.inner = left cls.__init__ = __init_unary__ return cls @@ -110,7 +111,7 @@ class ReducibleExpr(HasRepr, Protocol): """ def to_repr(self, repr_type: type, state: State) -> Repr: - """ See :py:meth:`HasRepr.to_repr` """ + """ See :py:meth:`mdpoly.representations.HasRepr.to_repr` """ return self.reduce().to_repr(repr_type, state) @abstractmethod @@ -262,10 +263,32 @@ class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra): @unary_operator +class PartialDiff(Expr, HasRepr, PolyRingAlgebra): + """ Partial differentiation of scalar polynomials """ + def __init__(self, with_respect_to: Var): + self.wrt = with_respect_to + + def to_repr(self, repr_type: type, state: State) -> Repr: + """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ + r = repr_type() + lrepr, state = self.left.to_repr(repr_type, state) + + entry = MatrixIndex.scalar + wrt = state.index(self.wrt) + + for term in lrepr.terms(entry): + if newterm := PolyIndex.differentiate(term, wrt): + r.set(entry, newterm, lrepr.at(entry, term)) + + return r, state + + def __repr__(self) -> str: + return f"(∂_{self.wrt} {self.inner})" + +@unary_operator class Diff(Expr, PolyRingAlgebra): """ Total differentiation. """ - # ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ # ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ # ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ diff --git a/mdpoly/types.py b/mdpoly/types.py index 1132260..1605aa8 100644 --- a/mdpoly/types.py +++ b/mdpoly/types.py @@ -2,9 +2,10 @@ from __future__ import annotations from .constants import NUMERICS_EPS from .errors import InvalidShape +from .util import partition from itertools import filterfalse -from typing import Self, NamedTuple, Iterable, TYPE_CHECKING +from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING if TYPE_CHECKING: from .state import State @@ -161,8 +162,7 @@ class PolyIndex(tuple[PolyVarIndex]): @classmethod def product(cls, left: Self, right: Self) -> Self: - """ - Compute index of product. + """ Compute index of product. For example if left is the index of :math:`xy` and right is the index of :math:`y^2` this functions returns the index of :math:`xy^3`. @@ -181,3 +181,31 @@ class PolyIndex(tuple[PolyVarIndex]): result[idx] += power return cls.sort(cls.from_dict(result)) + + @classmethod + def differentiate(cls, index: Self, wrt: int) -> Optional[Self]: + """ Compute the index of differentiation + + For example if the index is :math:`xy^2` and ``wrt`` is the index of + :math:`y` this function returns the index of :math:`xy`. In particular + this function takes care of the edge cases of differentiating a + constant and differentiating linear terms. Specifically, if the index + is a constant ``None`` is returned. + """ + + if cls.is_constant(index): + return None + + def is_wrt_var(idx: PolyVarIndex) -> bool: + return idx.var_idx == wrt + + others, with_wrt_var = partition(is_wrt_var, index) + with_wrt_var, *_ = with_wrt_var # Take first, as should be only one + + # Check if is linear term + if (with_wrt_var.power - 1) < NUMERICS_EPS: + return cls.sort(tuple(others) + cls.constant) + + # Decrease exponent + new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1)) + return cls.sort(tuple(others) + (new_idx,)) diff --git a/mdpoly/util.py b/mdpoly/util.py new file mode 100644 index 0000000..f1d5de2 --- /dev/null +++ b/mdpoly/util.py @@ -0,0 +1,12 @@ +from itertools import tee, filterfalse + +def partition(pred, iterable): + """Partition entries into false entries and true entries. + + If *pred* is slow, consider wrapping it with functools.lru_cache(). + """ + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + |