aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/abc.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 9977d65..a2d74c8 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -14,6 +14,16 @@ class Leaf(Protocol):
name: str
shape: Shape
+ def children(self) -> tuple:
+ """ Return an empty tuple. This is here to simplify recursive code on
+ expressions. """
+ return tuple()
+
+ def __iter__(self):
+ """ Return an empty iterator. This is here to simplify recursive code
+ on expressions. """
+ yield from iter(self.children())
+
@runtime_checkable
class Expr(Protocol):
@@ -42,9 +52,40 @@ class Expr(Protocol):
yield self.right
else:
yield from self.right.leaves()
+
+ def replace(self, leaf: Leaf, new_leaf: Leaf) -> None:
+ """ Create a new expression wherein all leaves equal to ``leaf`` are
+ replaced with ``new_leaf``.
+
+ This can be used to convert variables into parameters or constants, and
+ vice versa. For example:
+
+ .. code:: py
+
+ x, y = Variable.from_names("x, y")
+ poly_2d = x ** 2 + 2 * x * y + y ** 2
+
+ # Suppose we want to consider only x as variable to reduce poly_2d
+ # to a one-dimensional polynomial
+ yp = Parameter("y")
+ poly_1d = poly_2d.replace(y, yp)
+ """
+ def replace_leaves(node):
+ if isinstance(node, Leaf):
+ if node == leaf:
+ return new_leaf
+ else:
+ return node
+
+ left = replace_leaves(node.left)
+ right = replace_leaves(node.right)
+ return node.__class__(left, right)
+
+ return replace_leaves(self)
+
def __iter__(self) -> Sequence[Self | Leaf]:
- return (self.left, self.right)
+ yield from self.children()
class Repr(Protocol):