From c43467a8974812e0a35e4dd30196996e06070639 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Wed, 20 Mar 2024 11:12:52 +0100
Subject: Re-implement __matmul__

---
 mdpoly/expressions.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index 2824332..edaf1aa 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
 
 from dataclasses import dataclass
 from functools import wraps
-from typing import Type, Callable, Any, Self
+from typing import Type, Callable, Any, Self, Sequence
 
 from .abc import Expr, Var, Repr
 from .index import Number
@@ -13,7 +13,7 @@ from .representations import SparseRepr
 
 from .operations import WithTraceback
 from .operations.add import MatAdd, MatSub
-from .operations.mul import MatElemMul
+from .operations.mul import MatElemMul, MatMul
 from .operations.exp import PolyExp
 from .operations.derivative import PolyPartialDiff
 
@@ -186,11 +186,19 @@ class WithOps:
             # TODO: what if it is a matrix? Elementwise power?
             return PolyExp(left, right)
 
-    def __matmul__(self, other: Any) -> Self:
-        raise NotImplementedError
+    @wrap_result
+    def __matmul__(self, other: WithOps) -> Expr:
+        other = self._ensure_is_withops(other)
+        with self as left, other as right:
+            return MatMul(left, right)
 
-    def __rmatmul__(self, other: Any) -> Self:
-        raise NotImplementedError
+    @wrap_result
+    def __rmatmul__(self, other: Number | Sequence[Sequence[Number]]) -> Expr:
+        # FIXME: _ensure_is_withops does not handle matrices,
+        # must fix to take eg row major Sequence[Sequence[Number]] and / or np.NDArray
+        other = self._ensure_is_withops(other)
+        with other as left, self as right:
+            return MatMul(left, right)
 
     def __truediv__(self, other: Any) -> Self:
         raise NotImplementedError
-- 
cgit v1.2.1