aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/algebra.py28
-rw-r--r--mdpoly/types.py21
2 files changed, 44 insertions, 5 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 7085b42..670b537 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -161,7 +161,8 @@ class Add(Expr, HasRepr):
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if self.left.shape != self.right.shape:
- raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
+ raise InvalidShape(f"Cannot add {self.left} and {self.right} with "
+ f"shapes {self.left.shape} and {self.right.shape}.")
return self.left.shape
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
@@ -494,9 +495,23 @@ class MatScalarMul(ReducibleExpr, MatrixAlgebra):
raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.")
+
@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatInnerProd(Expr, MatrixAlgebra):
- """ Inner product. """
+class MatDotProd(Expr, MatrixAlgebra):
+ """ Dot product. """
+
+ @property
+ def shape(self) -> Shape:
+ if not self.left.shape.is_row():
+ raise AlgebraicError(f"Left operand {self.left} must be a row!")
+
+ if not self.right.shape.is_col():
+ raise AlgebraicError(f"Right operand {self.right} must be a column!")
+
+ if self.left.shape.cols != self.right.shape.rows:
+ raise AlgebraicError(f"Rows of {self.right} and columns {self.left} do not match!")
+
+ return Shape.scalar()
@binary_operator(MatrixAlgebra, MatrixAlgebra)
@@ -506,7 +521,12 @@ class MatMul(Expr, MatrixAlgebra):
@property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
- raise NotImplementedError
+ if not self.left.shape.rows == self.right.shape.cols:
+ raise AlgebraicError("Cannot perform matrix multiplication between "
+ f"{self.left} and {self.right} (shapes {self.left.shape} and {self.right.shape})")
+
+ return Shape(self.left.shape.rows, self.right.shape.cols)
+
def __repr__(self) -> str:
return f"({self.left} @ {self.right})"
diff --git a/mdpoly/types.py b/mdpoly/types.py
index 46c5882..9e4ecb7 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -37,12 +37,31 @@ class Shape(NamedTuple):
raise InvalidShape("Row vector must have dimension >= 1")
return cls(1, n)
+ def is_row(self) -> bool:
+ return self.cols == 1
+
@classmethod
- def column(cls, n: int) -> Self:
+ def col(cls, n: int) -> Self:
if n <= 0:
raise InvalidShape("Column vector must have dimension >= 1")
return cls(n, 1)
+ def is_col(self) -> bool:
+ return self.rows == 1
+
+ # --- aliases / shorthands ---
+
+ @property
+ def columns(self):
+ return self.cols
+
+ @classmethod
+ def column(cls, n: int) -> Self:
+ return cls.col(n)
+
+ def is_column(self) -> bool:
+ return self.is_col()
+
class MatrixIndex(NamedTuple):
""" Tuple to index an element of a matrix or vector. """