diff options
-rw-r--r-- | mdpoly/abc.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index a2d74c8..1ceac87 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -53,9 +53,9 @@ class Expr(Protocol): else: yield from self.right.leaves() - def replace(self, leaf: Leaf, new_leaf: Leaf) -> None: - """ Create a new expression wherein all leaves equal to ``leaf`` are - replaced with ``new_leaf``. + def replace(self, old: Self, new: Self) -> Self: + """ Create a new expression wherein all expression equal to ``old`` are + replaced with ``new``. This can be used to convert variables into parameters or constants, and vice versa. For example: @@ -69,19 +69,22 @@ class Expr(Protocol): # to a one-dimensional polynomial yp = Parameter("y") poly_1d = poly_2d.replace(y, yp) + + It can also be used for more advanced replacements since ``old`` and + ``new`` are free to be any expression. """ - def replace_leaves(node): + def replace_all(node): + if node == old: + return new + if isinstance(node, Leaf): - if node == leaf: - return new_leaf - else: - return node + return node - left = replace_leaves(node.left) - right = replace_leaves(node.right) + left = replace_all(node.left) + right = replace_all(node.right) return node.__class__(left, right) - return replace_leaves(self) + return replace_all(self) def __iter__(self) -> Sequence[Self | Leaf]: |