From a6ba5e281e7eaf9850920f01de3274622227031f Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sun, 10 Mar 2024 10:11:54 +0100 Subject: Fix PolyPartialDiff.replace --- mdpoly/algebra.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 04a61ad..1e4b775 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -6,7 +6,7 @@ from .errors import AlgebraicError, InvalidShape, MissingParameters from .state import State from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number -from typing import cast, Sequence, Iterable, Type, TypeVar +from typing import cast, Sequence, Iterable, Type, TypeVar, Self from functools import reduce from itertools import product, chain, combinations_with_replacement from abc import abstractmethod @@ -341,7 +341,7 @@ class PolyExp(BinaryOp, PolyRingExpr, Reducible): class PolyPartialDiff(UnaryOp, PolyRingExpr): """ Partial differentiation of scalar polynomials. """ - def __init__(self, inner: Expr, with_respect_to: Var): + def __init__(self, inner: Expr, with_respect_to: PolyVar): UnaryOp.__init__(self, inner) self.wrt = with_respect_to @@ -367,6 +367,18 @@ class PolyPartialDiff(UnaryOp, PolyRingExpr): def __repr__(self) -> str: return f"(∂_{self.wrt} {self.inner})" + def replace(self, old: PolyRingExpr, new: PolyRingExpr) -> Self: + """ Overloads :py:meth:`mdpoly.abc.Expr.replace` """ + if self.wrt == old: + if not isinstance(new, PolyVar): + # FIXME: implement chain rule + raise AlgebraicError(f"Cannot take a derivative with respect to {new}.") + + self.wrt = cast(PolyVar, new) + + return cast(Self, Expr.replace(self, old, new)) + + # ┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ # ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ -- cgit v1.2.1