From 8c4caecc1cc1cfbacbd38a69a6aee4a87e317268 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 4 Mar 2024 01:30:28 +0100 Subject: Add shape check for MatScalarMul --- mdpoly/algebra.py | 49 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 46a82b6..ef3f7eb 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -8,7 +8,7 @@ from .types import Shape, MatrixIndex, PolyIndex from .representations import HasRepr from typing import Protocol, TypeVar, Any, runtime_checkable -from functools import wraps, reduce +from functools import wraps, reduce, cached_property from itertools import product from enum import Enum from abc import abstractmethod @@ -158,6 +158,7 @@ def unary_operator(inner_type: AlgebraicStructure): class Add(Expr, HasRepr): """ Generic addition (no type check) """ + @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if self.left.shape != self.right.shape: @@ -189,6 +190,7 @@ class Add(Expr, HasRepr): class Sub(Expr, ReducibleExpr): """ Generic subtraction operator (no type check) """ + @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if self.left.shape != self.right.shape: @@ -207,6 +209,7 @@ class Sub(Expr, ReducibleExpr): class Mul(Expr, HasRepr): """ Generic multiplication operator (no type check). """ + @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if self.left.shape != self.right.shape: @@ -238,6 +241,7 @@ class Mul(Expr, HasRepr): class Exp(Expr, ReducibleExpr): """ Generic exponentiation (no type check). """ + @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if self.left.shape != self.right.shape: @@ -349,6 +353,7 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra): def __init__(self, with_respect_to: Var): self.wrt = with_respect_to + @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ return self.inner.shape @@ -442,16 +447,16 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): def transpose(self): """ Matrix transposition. """ - raise NotImplementedError + return MatTranspose(self) @property def T(self): """ Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """ return self.transpose() - def to_scalar(self, scalar_type: type): - """ Convert to a scalar expression. """ - raise NotImplementedError + # def to_scalar(self, scalar_type: type): + # """ Convert to a scalar expression. """ + # raise NotImplementedError @binary_operator(MatrixAlgebra, MatrixAlgebra) @@ -466,7 +471,7 @@ class MatSub(Sub, MatrixAlgebra): @binary_operator(MatrixAlgebra, MatrixAlgebra) class MatElemMul(Mul, MatrixAlgebra): - """ Elementwise Matrix Multiplication """ + """ Elementwise Matrix Multiplication. """ def __repr__(self) -> str: return f"({self.left} .* {self.right})" @@ -475,10 +480,38 @@ class MatElemMul(Mul, MatrixAlgebra): @binary_operator(MatrixAlgebra, MatrixAlgebra) class MatScalarMul(ReducibleExpr, MatrixAlgebra): """ Matrix-Scalar Multiplication. """ - def reduce(self) -> Expr: - raise NotImplementedError + + @property + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if self.right.shape == Shape.scalar: + return self.left.shape + + elif self.left.shape == Shape.scalar: + return self.right.shape + + raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.") + +@binary_operator(MatrixAlgebra, MatrixAlgebra) +class MatInnerProd(Expr, MatrixAlgebra): + """ Inner product. """ @binary_operator(MatrixAlgebra, MatrixAlgebra) class MatMul(Expr, MatrixAlgebra): """ Matrix Multiplication. """ + + @property + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + ... + + def __repr__(self) -> str: + return f"({self.left} @ {self.right})" + + +@unary_operator(MatrixAlgebra) +class MatTranspose(HasRepr, MatrixAlgebra): + """ Matrix transposition """ + + # def to_repr(self, -- cgit v1.2.1