diff options
author | Nao Pross <np@0hm.ch> | 2024-03-03 18:28:23 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-03 18:28:23 +0100 |
commit | f7f7c927627c36bd217fa4dad8ba9d05c6f050af (patch) | |
tree | ce1589e1e226c469f6163b9559299ed25201098d | |
parent | Add method Expr.replace() to replace leaves (diff) | |
download | mdpoly-f7f7c927627c36bd217fa4dad8ba9d05c6f050af.tar.gz mdpoly-f7f7c927627c36bd217fa4dad8ba9d05c6f050af.zip |
Extend Expr.replace() to work with any expression
-rw-r--r-- | mdpoly/abc.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index a2d74c8..1ceac87 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -53,9 +53,9 @@ class Expr(Protocol): else: yield from self.right.leaves() - def replace(self, leaf: Leaf, new_leaf: Leaf) -> None: - """ Create a new expression wherein all leaves equal to ``leaf`` are - replaced with ``new_leaf``. + def replace(self, old: Self, new: Self) -> Self: + """ Create a new expression wherein all expression equal to ``old`` are + replaced with ``new``. This can be used to convert variables into parameters or constants, and vice versa. For example: @@ -69,19 +69,22 @@ class Expr(Protocol): # to a one-dimensional polynomial yp = Parameter("y") poly_1d = poly_2d.replace(y, yp) + + It can also be used for more advanced replacements since ``old`` and + ``new`` are free to be any expression. """ - def replace_leaves(node): + def replace_all(node): + if node == old: + return new + if isinstance(node, Leaf): - if node == leaf: - return new_leaf - else: - return node + return node - left = replace_leaves(node.left) - right = replace_leaves(node.right) + left = replace_all(node.left) + right = replace_all(node.right) return node.__class__(left, right) - return replace_leaves(self) + return replace_all(self) def __iter__(self) -> Sequence[Self | Leaf]: |