aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-19 00:17:03 +0100
committerNao Pross <np@0hm.ch>2024-03-19 00:48:28 +0100
commit290f814d9803aee04eb26ee5b4a76a021825caa2 (patch)
tree8166337edb6be7c1c4283d354c24677c082c4cf4
parentAdd comments and improve docstrings (diff)
downloadmdpoly-290f814d9803aee04eb26ee5b4a76a021825caa2.tar.gz
mdpoly-290f814d9803aee04eb26ee5b4a76a021825caa2.zip
Implement matrix product
-rw-r--r--mdpoly/operations/mul.py25
-rw-r--r--mdpoly/operations/transpose.py2
2 files changed, 23 insertions, 4 deletions
diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py
index 1c5a629..04796a2 100644
--- a/mdpoly/operations/mul.py
+++ b/mdpoly/operations/mul.py
@@ -6,11 +6,13 @@ from itertools import product
from dataclasses import dataclass
from dataclassabc import dataclassabc
+from ..abc import Expr
from ..index import Shape
from ..errors import AlgebraicError, InvalidShape
from ..index import MatrixIndex, PolyIndex
from . import BinaryOp, Reducible
+from .transpose import MatTranspose
if TYPE_CHECKING:
from ..abc import ReprT
@@ -107,7 +109,8 @@ class MatMul(BinaryOp):
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if not self.left.shape.rows == self.right.shape.cols:
raise AlgebraicError("Cannot perform matrix multiplication between "
- f"{self.left} and {self.right} (shapes {self.left.shape} and {self.right.shape})")
+ f"{self.left} and {self.right} (shapes {self.left.shape} "
+ f"and {self.right.shape})")
return Shape(self.left.shape.rows, self.right.shape.cols)
@@ -121,8 +124,21 @@ class MatMul(BinaryOp):
# Compute matrix product
for row in range(self.left.shape.rows):
for col in range(self.right.shape.cols):
- for i in range(self.left.shape.cols):
- raise NotImplementedError
+ # Entry of result
+ entry = MatrixIndex(row, col)
+
+ for k in range(self.left.shape.cols):
+ lentry = MatrixIndex(row, k)
+ rentry = MatrixIndex(k, col)
+
+ # Product of polynomials at lentry and rentry
+ for lterm, rterm in product(lrepr.terms(lentry), rrepr.terms(rentry)):
+ # Compute index of product
+ term = PolyIndex.product(lterm, rterm)
+ # Compute product
+ p = r.at(entry, term)
+ p += lrepr.at(lentry, term) * rrepr.at(rentry, term)
+ r.set(entry, term, p)
return r, state
@@ -148,6 +164,9 @@ class MatDotProd(BinaryOp, Reducible):
return Shape.scalar()
+ def reduce(self) -> Expr:
+ return MatMul(MatTranspose(self.left), self.right)
+
# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓┏━┓┏━┓╺┳┓╻ ╻┏━╸╺┳╸
# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┛┣┳┛┃ ┃ ┃┃┃ ┃┃ ┃
diff --git a/mdpoly/operations/transpose.py b/mdpoly/operations/transpose.py
index 0e4228b..c67c9d2 100644
--- a/mdpoly/operations/transpose.py
+++ b/mdpoly/operations/transpose.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from dataclasses import dataclass
-from ..expressions import UnaryOp
+from . import UnaryOp
if TYPE_CHECKING:
from ..index import Shape