diff options
Diffstat (limited to 'mdpoly/operations/mul.py')
-rw-r--r-- | mdpoly/operations/mul.py | 186 |
1 files changed, 186 insertions, 0 deletions
diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py new file mode 100644 index 0000000..a2a3cc4 --- /dev/null +++ b/mdpoly/operations/mul.py @@ -0,0 +1,186 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from typing import Type +from itertools import product +from dataclasses import dataclass +from dataclassabc import dataclassabc + +from ..index import Shape +from ..errors import AlgebraicError, InvalidShape +from ..index import MatrixIndex, PolyIndex + +from ..expressions import BinaryOp, Reducible +from ..expressions.matrix import MatrixExpr +from ..expression.poly import PolyRingExpr + +if TYPE_CHECKING: + from ..abc import ReprT + from ..state import State + + +# ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ┏━┓┏━┓┏━┓╺┳┓╻ ╻┏━╸╺┳╸┏━┓ +# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┣━┛┣┳┛┃ ┃ ┃┃┃ ┃┃ ┃ ┗━┓ +# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗╸┗━┛╺┻┛┗━┛┗━╸ ╹ ┗━┛ + + +@dataclassabc +class MatElemMul(BinaryOp, MatrixExpr): + """ Elementwise Matrix Multiplication. """ + + @property + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if not self.left.shape == self.right.shape: + raise AlgebraicError("Cannot perform element-wise multiplication of matrices " + f"{self.left} and {self.right} with different shapes, " + f"{self.left.shape} and {self.right.shape}") + return self.left.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ + r = repr_type(self.shape) + + lrepr, state = self.left.to_repr(repr_type, state) + rrepr, state = self.right.to_repr(repr_type, state) + + # Non zero entries are the intersection since if either is zero the + # result is zero + nonzero_entries = set(lrepr.entries()) & set(rrepr.entries()) + for entry in nonzero_entries: + # Compute polynomial product between non-zero entries + for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): + # Compute where the results should go + term = PolyIndex.product(lterm, rterm) + + # Compute product + p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) + r.set(entry, term, p) + + return r, state + + def __str__(self) -> str: + return f"({self.left} .* {self.right})" + + +@dataclassabc +class MatScalarMul(BinaryOp, MatrixExpr): + """ Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix + on the right. """ + + @property + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if not self.left.shape == Shape.scalar(): + raise InvalidShape(f"Matrix-scalar product assumes that left argumet {self.left} " + f"but it has shape {self.left.shape}") + + return self.right.shape + + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ + r = repr_type(self.shape) + + scalar_repr, state = self.left.to_repr(repr_type, state) + mat_repr, state = self.right.to_repr(repr_type, state) + + for entry in mat_repr.entries(): + scalar_terms = scalar_repr.terms(MatrixIndex.scalar()) + mat_terms = mat_repr.terms(entry) + for scalar_term, mat_term in product(scalar_terms, mat_terms): + term = PolyIndex.product(scalar_term, mat_term) + + p = r.at(entry, term) + scalar_repr.at(entry, scalar_term) + mat_repr.at(entry, mat_term) + r.set(entry, term, p) + + return r, state + + def __str__(self) -> str: + return f"({self.left} * {self.right})" + + +@dataclass(eq=False) +class MatMul(BinaryOp, MatrixExpr): + """ Matrix Multiplication. """ + + @property + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if not self.left.shape.rows == self.right.shape.cols: + raise AlgebraicError("Cannot perform matrix multiplication between " + f"{self.left} and {self.right} (shapes {self.left.shape} and {self.right.shape})") + + return Shape(self.left.shape.rows, self.right.shape.cols) + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ + r = repr_type(self.shape) + + lrepr, state = self.left.to_repr(repr_type, state) + rrepr, state = self.right.to_repr(repr_type, state) + + # Compute matrix product + for row in range(self.left.shape.rows): + for col in range(self.right.shape.cols): + for i in range(self.left.shape.cols): + raise NotImplementedError + + return r, state + + + def __str__(self) -> str: + return f"({self.left} @ {self.right})" + + +@dataclass(eq=False) +class MatDotProd(BinaryOp, MatrixExpr, Reducible): + """ Dot product. """ + + @property + def shape(self) -> Shape: + if not self.left.shape.is_row(): + raise AlgebraicError(f"Left operand {self.left} must be a row!") + + if not self.right.shape.is_col(): + raise AlgebraicError(f"Right operand {self.right} must be a column!") + + if self.left.shape.cols != self.right.shape.rows: + raise AlgebraicError(f"Rows of {self.right} and columns {self.left} do not match!") + + return Shape.scalar() + + +# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓┏━┓┏━┓╺┳┓╻ ╻┏━╸╺┳╸ +# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┛┣┳┛┃ ┃ ┃┃┃ ┃┃ ┃ +# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗╸┗━┛╺┻┛┗━┛┗━╸ ╹ + + +@dataclass(eq=False) +class PolyMul(BinaryOp, PolyRingExpr): + """ Multiplication operator between scalar polynomials. """ + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ + r = repr_type(self.shape) + + lrepr, state = self.left.to_repr(repr_type, state) + rrepr, state = self.right.to_repr(repr_type, state) + + # Non zero entries are the intersection since if either is zero the + # result is zero + nonzero_entries = set(lrepr.entries()) & set(rrepr.entries()) + for entry in nonzero_entries: + # Compute polynomial product between non-zero entries + for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): + # Compute where the results should go + term = PolyIndex.product(lterm, rterm) + + # Compute product + p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) + r.set(entry, term, p) + + return r, state + + def __str__(self) -> str: + return f"({self.left} * {self.right})" |