diff options
author | Nao Pross <np@0hm.ch> | 2024-03-20 11:12:52 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-20 11:12:52 +0100 |
commit | c43467a8974812e0a35e4dd30196996e06070639 (patch) | |
tree | 044f40736c92352a26fd5b08e67d1fbc0507a1f6 | |
parent | Fix MatSub, separate leaves (diff) | |
download | mdpoly-c43467a8974812e0a35e4dd30196996e06070639.tar.gz mdpoly-c43467a8974812e0a35e4dd30196996e06070639.zip |
Re-implement __matmul__
Diffstat (limited to '')
-rw-r--r-- | mdpoly/expressions.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index 2824332..edaf1aa 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from dataclasses import dataclass from functools import wraps -from typing import Type, Callable, Any, Self +from typing import Type, Callable, Any, Self, Sequence from .abc import Expr, Var, Repr from .index import Number @@ -13,7 +13,7 @@ from .representations import SparseRepr from .operations import WithTraceback from .operations.add import MatAdd, MatSub -from .operations.mul import MatElemMul +from .operations.mul import MatElemMul, MatMul from .operations.exp import PolyExp from .operations.derivative import PolyPartialDiff @@ -186,11 +186,19 @@ class WithOps: # TODO: what if it is a matrix? Elementwise power? return PolyExp(left, right) - def __matmul__(self, other: Any) -> Self: - raise NotImplementedError + @wrap_result + def __matmul__(self, other: WithOps) -> Expr: + other = self._ensure_is_withops(other) + with self as left, other as right: + return MatMul(left, right) - def __rmatmul__(self, other: Any) -> Self: - raise NotImplementedError + @wrap_result + def __rmatmul__(self, other: Number | Sequence[Sequence[Number]]) -> Expr: + # FIXME: _ensure_is_withops does not handle matrices, + # must fix to take eg row major Sequence[Sequence[Number]] and / or np.NDArray + other = self._ensure_is_withops(other) + with other as left, self as right: + return MatMul(left, right) def __truediv__(self, other: Any) -> Self: raise NotImplementedError |