diff options
-rw-r--r-- | mdpoly/operations/mul.py | 25 | ||||
-rw-r--r-- | mdpoly/operations/transpose.py | 2 |
2 files changed, 23 insertions, 4 deletions
diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py index 1c5a629..04796a2 100644 --- a/mdpoly/operations/mul.py +++ b/mdpoly/operations/mul.py @@ -6,11 +6,13 @@ from itertools import product from dataclasses import dataclass from dataclassabc import dataclassabc +from ..abc import Expr from ..index import Shape from ..errors import AlgebraicError, InvalidShape from ..index import MatrixIndex, PolyIndex from . import BinaryOp, Reducible +from .transpose import MatTranspose if TYPE_CHECKING: from ..abc import ReprT @@ -107,7 +109,8 @@ class MatMul(BinaryOp): """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if not self.left.shape.rows == self.right.shape.cols: raise AlgebraicError("Cannot perform matrix multiplication between " - f"{self.left} and {self.right} (shapes {self.left.shape} and {self.right.shape})") + f"{self.left} and {self.right} (shapes {self.left.shape} " + f"and {self.right.shape})") return Shape(self.left.shape.rows, self.right.shape.cols) @@ -121,8 +124,21 @@ class MatMul(BinaryOp): # Compute matrix product for row in range(self.left.shape.rows): for col in range(self.right.shape.cols): - for i in range(self.left.shape.cols): - raise NotImplementedError + # Entry of result + entry = MatrixIndex(row, col) + + for k in range(self.left.shape.cols): + lentry = MatrixIndex(row, k) + rentry = MatrixIndex(k, col) + + # Product of polynomials at lentry and rentry + for lterm, rterm in product(lrepr.terms(lentry), rrepr.terms(rentry)): + # Compute index of product + term = PolyIndex.product(lterm, rterm) + # Compute product + p = r.at(entry, term) + p += lrepr.at(lentry, term) * rrepr.at(rentry, term) + r.set(entry, term, p) return r, state @@ -148,6 +164,9 @@ class MatDotProd(BinaryOp, Reducible): return Shape.scalar() + def reduce(self) -> Expr: + return MatMul(MatTranspose(self.left), self.right) + # ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓┏━┓┏━┓╺┳┓╻ ╻┏━╸╺┳╸ # ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┛┣┳┛┃ ┃ ┃┃┃ ┃┃ ┃ diff --git a/mdpoly/operations/transpose.py b/mdpoly/operations/transpose.py index 0e4228b..c67c9d2 100644 --- a/mdpoly/operations/transpose.py +++ b/mdpoly/operations/transpose.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from dataclasses import dataclass -from ..expressions import UnaryOp +from . import UnaryOp if TYPE_CHECKING: from ..index import Shape |