aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/abc.py25
1 files changed, 14 insertions, 11 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index a2d74c8..1ceac87 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -53,9 +53,9 @@ class Expr(Protocol):
else:
yield from self.right.leaves()
- def replace(self, leaf: Leaf, new_leaf: Leaf) -> None:
- """ Create a new expression wherein all leaves equal to ``leaf`` are
- replaced with ``new_leaf``.
+ def replace(self, old: Self, new: Self) -> Self:
+ """ Create a new expression wherein all expression equal to ``old`` are
+ replaced with ``new``.
This can be used to convert variables into parameters or constants, and
vice versa. For example:
@@ -69,19 +69,22 @@ class Expr(Protocol):
# to a one-dimensional polynomial
yp = Parameter("y")
poly_1d = poly_2d.replace(y, yp)
+
+ It can also be used for more advanced replacements since ``old`` and
+ ``new`` are free to be any expression.
"""
- def replace_leaves(node):
+ def replace_all(node):
+ if node == old:
+ return new
+
if isinstance(node, Leaf):
- if node == leaf:
- return new_leaf
- else:
- return node
+ return node
- left = replace_leaves(node.left)
- right = replace_leaves(node.right)
+ left = replace_all(node.left)
+ right = replace_all(node.right)
return node.__class__(left, right)
- return replace_leaves(self)
+ return replace_all(self)
def __iter__(self) -> Sequence[Self | Leaf]: