diff options
Diffstat (limited to '')
-rw-r--r-- | mdpoly/abc.py | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 9977d65..a2d74c8 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -14,6 +14,16 @@ class Leaf(Protocol): name: str shape: Shape + def children(self) -> tuple: + """ Return an empty tuple. This is here to simplify recursive code on + expressions. """ + return tuple() + + def __iter__(self): + """ Return an empty iterator. This is here to simplify recursive code + on expressions. """ + yield from iter(self.children()) + @runtime_checkable class Expr(Protocol): @@ -42,9 +52,40 @@ class Expr(Protocol): yield self.right else: yield from self.right.leaves() + + def replace(self, leaf: Leaf, new_leaf: Leaf) -> None: + """ Create a new expression wherein all leaves equal to ``leaf`` are + replaced with ``new_leaf``. + + This can be used to convert variables into parameters or constants, and + vice versa. For example: + + .. code:: py + + x, y = Variable.from_names("x, y") + poly_2d = x ** 2 + 2 * x * y + y ** 2 + + # Suppose we want to consider only x as variable to reduce poly_2d + # to a one-dimensional polynomial + yp = Parameter("y") + poly_1d = poly_2d.replace(y, yp) + """ + def replace_leaves(node): + if isinstance(node, Leaf): + if node == leaf: + return new_leaf + else: + return node + + left = replace_leaves(node.left) + right = replace_leaves(node.right) + return node.__class__(left, right) + + return replace_leaves(self) + def __iter__(self) -> Sequence[Self | Leaf]: - return (self.left, self.right) + yield from self.children() class Repr(Protocol): |