From f7f7c927627c36bd217fa4dad8ba9d05c6f050af Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sun, 3 Mar 2024 18:28:23 +0100 Subject: Extend Expr.replace() to work with any expression --- mdpoly/abc.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/mdpoly/abc.py b/mdpoly/abc.py index a2d74c8..1ceac87 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -53,9 +53,9 @@ class Expr(Protocol): else: yield from self.right.leaves() - def replace(self, leaf: Leaf, new_leaf: Leaf) -> None: - """ Create a new expression wherein all leaves equal to ``leaf`` are - replaced with ``new_leaf``. + def replace(self, old: Self, new: Self) -> Self: + """ Create a new expression wherein all expression equal to ``old`` are + replaced with ``new``. This can be used to convert variables into parameters or constants, and vice versa. For example: @@ -69,19 +69,22 @@ class Expr(Protocol): # to a one-dimensional polynomial yp = Parameter("y") poly_1d = poly_2d.replace(y, yp) + + It can also be used for more advanced replacements since ``old`` and + ``new`` are free to be any expression. """ - def replace_leaves(node): + def replace_all(node): + if node == old: + return new + if isinstance(node, Leaf): - if node == leaf: - return new_leaf - else: - return node + return node - left = replace_leaves(node.left) - right = replace_leaves(node.right) + left = replace_all(node.left) + right = replace_all(node.right) return node.__class__(left, right) - return replace_leaves(self) + return replace_all(self) def __iter__(self) -> Sequence[Self | Leaf]: -- cgit v1.2.1