aboutsummaryrefslogtreecommitdiffstats
path: root/mdpoly/operations/mul.py
diff options
context:
space:
mode:
Diffstat (limited to 'mdpoly/operations/mul.py')
-rw-r--r--mdpoly/operations/mul.py186
1 files changed, 186 insertions, 0 deletions
diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py
new file mode 100644
index 0000000..a2a3cc4
--- /dev/null
+++ b/mdpoly/operations/mul.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from typing import Type
+from itertools import product
+from dataclasses import dataclass
+from dataclassabc import dataclassabc
+
+from ..index import Shape
+from ..errors import AlgebraicError, InvalidShape
+from ..index import MatrixIndex, PolyIndex
+
+from ..expressions import BinaryOp, Reducible
+from ..expressions.matrix import MatrixExpr
+from ..expression.poly import PolyRingExpr
+
+if TYPE_CHECKING:
+ from ..abc import ReprT
+ from ..state import State
+
+
+# ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ┏━┓┏━┓┏━┓╺┳┓╻ ╻┏━╸╺┳╸┏━┓
+# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┣━┛┣┳┛┃ ┃ ┃┃┃ ┃┃ ┃ ┗━┓
+# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗╸┗━┛╺┻┛┗━┛┗━╸ ╹ ┗━┛
+
+
+@dataclassabc
+class MatElemMul(BinaryOp, MatrixExpr):
+ """ Elementwise Matrix Multiplication. """
+
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if not self.left.shape == self.right.shape:
+ raise AlgebraicError("Cannot perform element-wise multiplication of matrices "
+ f"{self.left} and {self.right} with different shapes, "
+ f"{self.left.shape} and {self.right.shape}")
+ return self.left.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ r = repr_type(self.shape)
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ # Non zero entries are the intersection since if either is zero the
+ # result is zero
+ nonzero_entries = set(lrepr.entries()) & set(rrepr.entries())
+ for entry in nonzero_entries:
+ # Compute polynomial product between non-zero entries
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
+
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
+
+ return r, state
+
+ def __str__(self) -> str:
+ return f"({self.left} .* {self.right})"
+
+
+@dataclassabc
+class MatScalarMul(BinaryOp, MatrixExpr):
+ """ Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix
+ on the right. """
+
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if not self.left.shape == Shape.scalar():
+ raise InvalidShape(f"Matrix-scalar product assumes that left argumet {self.left} "
+ f"but it has shape {self.left.shape}")
+
+ return self.right.shape
+
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ r = repr_type(self.shape)
+
+ scalar_repr, state = self.left.to_repr(repr_type, state)
+ mat_repr, state = self.right.to_repr(repr_type, state)
+
+ for entry in mat_repr.entries():
+ scalar_terms = scalar_repr.terms(MatrixIndex.scalar())
+ mat_terms = mat_repr.terms(entry)
+ for scalar_term, mat_term in product(scalar_terms, mat_terms):
+ term = PolyIndex.product(scalar_term, mat_term)
+
+ p = r.at(entry, term) + scalar_repr.at(entry, scalar_term) + mat_repr.at(entry, mat_term)
+ r.set(entry, term, p)
+
+ return r, state
+
+ def __str__(self) -> str:
+ return f"({self.left} * {self.right})"
+
+
+@dataclass(eq=False)
+class MatMul(BinaryOp, MatrixExpr):
+ """ Matrix Multiplication. """
+
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if not self.left.shape.rows == self.right.shape.cols:
+ raise AlgebraicError("Cannot perform matrix multiplication between "
+ f"{self.left} and {self.right} (shapes {self.left.shape} and {self.right.shape})")
+
+ return Shape(self.left.shape.rows, self.right.shape.cols)
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ r = repr_type(self.shape)
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ # Compute matrix product
+ for row in range(self.left.shape.rows):
+ for col in range(self.right.shape.cols):
+ for i in range(self.left.shape.cols):
+ raise NotImplementedError
+
+ return r, state
+
+
+ def __str__(self) -> str:
+ return f"({self.left} @ {self.right})"
+
+
+@dataclass(eq=False)
+class MatDotProd(BinaryOp, MatrixExpr, Reducible):
+ """ Dot product. """
+
+ @property
+ def shape(self) -> Shape:
+ if not self.left.shape.is_row():
+ raise AlgebraicError(f"Left operand {self.left} must be a row!")
+
+ if not self.right.shape.is_col():
+ raise AlgebraicError(f"Right operand {self.right} must be a column!")
+
+ if self.left.shape.cols != self.right.shape.rows:
+ raise AlgebraicError(f"Rows of {self.right} and columns {self.left} do not match!")
+
+ return Shape.scalar()
+
+
+# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓┏━┓┏━┓╺┳┓╻ ╻┏━╸╺┳╸
+# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┛┣┳┛┃ ┃ ┃┃┃ ┃┃ ┃
+# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗╸┗━┛╺┻┛┗━┛┗━╸ ╹
+
+
+@dataclass(eq=False)
+class PolyMul(BinaryOp, PolyRingExpr):
+ """ Multiplication operator between scalar polynomials. """
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ r = repr_type(self.shape)
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ # Non zero entries are the intersection since if either is zero the
+ # result is zero
+ nonzero_entries = set(lrepr.entries()) & set(rrepr.entries())
+ for entry in nonzero_entries:
+ # Compute polynomial product between non-zero entries
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
+
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
+
+ return r, state
+
+ def __str__(self) -> str:
+ return f"({self.left} * {self.right})"