summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-04 17:54:27 +0200
committerNao Pross <np@0hm.ch>2024-05-04 18:01:42 +0200
commitb4ffdbafa8d61aae2806e286ebdde63821f23092 (patch)
tree48909c4a21cdd4a999cc5e922cef86fca1f4fb4d
parentFix polymatrix.to_sympy() (diff)
downloadpolymatrix-b4ffdbafa8d61aae2806e286ebdde63821f23092.tar.gz
polymatrix-b4ffdbafa8d61aae2806e286ebdde63821f23092.zip
Fix Expression._binary, avoid creating unnecessary nodes
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/expression.py20
1 files changed, 7 insertions, 13 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index edb0830..3c52297 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -117,31 +117,25 @@ class Expression(ExpressionBaseMixin, ABC):
def _binary(op, left, right):
stack = get_stack_lines()
- if isinstance(left, Expression):
- right = polymatrix.expression.init.init_from_expr_or_none(right)
+ if isinstance(left, Expression) and isinstance(right, Expression):
+ return left.copy(underlying=op(left.underlying, right.underlying, stack))
- # delegate to upper level
+ elif isinstance(left, Expression):
+ right = polymatrix.expression.init.init_from_expr_or_none(right)
- # NP: what is upper level? Class inherits from ABC and base class,
- # NP: neither has binary operators, also NotImplemented IIRC is only for __lt__
- # NP: and other comparison methods and not for overloadings like __add__
if right is None:
return NotImplemented
- return left.copy(
- underlying=op(left, right, stack),
- )
+ return left.copy(underlying=op(left.underlying, right, stack))
+ # else right is an Expression
else:
left = polymatrix.expression.init.init_from_expr_or_none(left)
- # delegate to upper level
if left is None:
return NotImplemented
- return right.copy(
- underlying=op(left, right, stack),
- )
+ return right.copy( underlying=op(left, right.underlying, stack))
def cache(self) -> "Expression":
return self.copy(