aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-20 11:12:52 +0100
committerNao Pross <np@0hm.ch>2024-03-20 11:12:52 +0100
commitc43467a8974812e0a35e4dd30196996e06070639 (patch)
tree044f40736c92352a26fd5b08e67d1fbc0507a1f6
parentFix MatSub, separate leaves (diff)
downloadmdpoly-c43467a8974812e0a35e4dd30196996e06070639.tar.gz
mdpoly-c43467a8974812e0a35e4dd30196996e06070639.zip
Re-implement __matmul__
-rw-r--r--mdpoly/expressions.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index 2824332..edaf1aa 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from dataclasses import dataclass
from functools import wraps
-from typing import Type, Callable, Any, Self
+from typing import Type, Callable, Any, Self, Sequence
from .abc import Expr, Var, Repr
from .index import Number
@@ -13,7 +13,7 @@ from .representations import SparseRepr
from .operations import WithTraceback
from .operations.add import MatAdd, MatSub
-from .operations.mul import MatElemMul
+from .operations.mul import MatElemMul, MatMul
from .operations.exp import PolyExp
from .operations.derivative import PolyPartialDiff
@@ -186,11 +186,19 @@ class WithOps:
# TODO: what if it is a matrix? Elementwise power?
return PolyExp(left, right)
- def __matmul__(self, other: Any) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __matmul__(self, other: WithOps) -> Expr:
+ other = self._ensure_is_withops(other)
+ with self as left, other as right:
+ return MatMul(left, right)
- def __rmatmul__(self, other: Any) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __rmatmul__(self, other: Number | Sequence[Sequence[Number]]) -> Expr:
+ # FIXME: _ensure_is_withops does not handle matrices,
+ # must fix to take eg row major Sequence[Sequence[Number]] and / or np.NDArray
+ other = self._ensure_is_withops(other)
+ with other as left, self as right:
+ return MatMul(left, right)
def __truediv__(self, other: Any) -> Self:
raise NotImplementedError