From 6287ccff3b75a55cff121cd256e322a66ba1c449 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Tue, 5 Mar 2024 00:45:05 +0100 Subject: Implement more shape checks --- mdpoly/algebra.py | 28 ++++++++++++++++++++++++---- mdpoly/types.py | 21 ++++++++++++++++++++- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 7085b42..670b537 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -161,7 +161,8 @@ class Add(Expr, HasRepr): def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if self.left.shape != self.right.shape: - raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") + raise InvalidShape(f"Cannot add {self.left} and {self.right} with " + f"shapes {self.left.shape} and {self.right.shape}.") return self.left.shape def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: @@ -494,9 +495,23 @@ class MatScalarMul(ReducibleExpr, MatrixAlgebra): raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.") + @binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatInnerProd(Expr, MatrixAlgebra): - """ Inner product. """ +class MatDotProd(Expr, MatrixAlgebra): + """ Dot product. """ + + @property + def shape(self) -> Shape: + if not self.left.shape.is_row(): + raise AlgebraicError(f"Left operand {self.left} must be a row!") + + if not self.right.shape.is_col(): + raise AlgebraicError(f"Right operand {self.right} must be a column!") + + if self.left.shape.cols != self.right.shape.rows: + raise AlgebraicError(f"Rows of {self.right} and columns {self.left} do not match!") + + return Shape.scalar() @binary_operator(MatrixAlgebra, MatrixAlgebra) @@ -506,7 +521,12 @@ class MatMul(Expr, MatrixAlgebra): @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - raise NotImplementedError + if not self.left.shape.rows == self.right.shape.cols: + raise AlgebraicError("Cannot perform matrix multiplication between " + f"{self.left} and {self.right} (shapes {self.left.shape} and {self.right.shape})") + + return Shape(self.left.shape.rows, self.right.shape.cols) + def __repr__(self) -> str: return f"({self.left} @ {self.right})" diff --git a/mdpoly/types.py b/mdpoly/types.py index 46c5882..9e4ecb7 100644 --- a/mdpoly/types.py +++ b/mdpoly/types.py @@ -37,12 +37,31 @@ class Shape(NamedTuple): raise InvalidShape("Row vector must have dimension >= 1") return cls(1, n) + def is_row(self) -> bool: + return self.cols == 1 + @classmethod - def column(cls, n: int) -> Self: + def col(cls, n: int) -> Self: if n <= 0: raise InvalidShape("Column vector must have dimension >= 1") return cls(n, 1) + def is_col(self) -> bool: + return self.rows == 1 + + # --- aliases / shorthands --- + + @property + def columns(self): + return self.cols + + @classmethod + def column(cls, n: int) -> Self: + return cls.col(n) + + def is_column(self) -> bool: + return self.is_col() + class MatrixIndex(NamedTuple): """ Tuple to index an element of a matrix or vector. """ -- cgit v1.2.1