diff options
Diffstat (limited to '')
-rw-r--r-- | mdpoly/expressions.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index 2824332..edaf1aa 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from dataclasses import dataclass from functools import wraps -from typing import Type, Callable, Any, Self +from typing import Type, Callable, Any, Self, Sequence from .abc import Expr, Var, Repr from .index import Number @@ -13,7 +13,7 @@ from .representations import SparseRepr from .operations import WithTraceback from .operations.add import MatAdd, MatSub -from .operations.mul import MatElemMul +from .operations.mul import MatElemMul, MatMul from .operations.exp import PolyExp from .operations.derivative import PolyPartialDiff @@ -186,11 +186,19 @@ class WithOps: # TODO: what if it is a matrix? Elementwise power? return PolyExp(left, right) - def __matmul__(self, other: Any) -> Self: - raise NotImplementedError + @wrap_result + def __matmul__(self, other: WithOps) -> Expr: + other = self._ensure_is_withops(other) + with self as left, other as right: + return MatMul(left, right) - def __rmatmul__(self, other: Any) -> Self: - raise NotImplementedError + @wrap_result + def __rmatmul__(self, other: Number | Sequence[Sequence[Number]]) -> Expr: + # FIXME: _ensure_is_withops does not handle matrices, + # must fix to take eg row major Sequence[Sequence[Number]] and / or np.NDArray + other = self._ensure_is_withops(other) + with other as left, self as right: + return MatMul(left, right) def __truediv__(self, other: Any) -> Self: raise NotImplementedError |