aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-04 01:30:28 +0100
committerNao Pross <np@0hm.ch>2024-03-04 01:30:28 +0100
commit8c4caecc1cc1cfbacbd38a69a6aee4a87e317268 (patch)
tree7a7e34e587754db0405c255deaffce4bc3909093
parentAdd shape checks for PolyRingAlgebra (diff)
downloadmdpoly-8c4caecc1cc1cfbacbd38a69a6aee4a87e317268.tar.gz
mdpoly-8c4caecc1cc1cfbacbd38a69a6aee4a87e317268.zip
Add shape check for MatScalarMul
-rw-r--r--mdpoly/algebra.py49
1 files changed, 41 insertions, 8 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 46a82b6..ef3f7eb 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -8,7 +8,7 @@ from .types import Shape, MatrixIndex, PolyIndex
from .representations import HasRepr
from typing import Protocol, TypeVar, Any, runtime_checkable
-from functools import wraps, reduce
+from functools import wraps, reduce, cached_property
from itertools import product
from enum import Enum
from abc import abstractmethod
@@ -158,6 +158,7 @@ def unary_operator(inner_type: AlgebraicStructure):
class Add(Expr, HasRepr):
""" Generic addition (no type check) """
+ @property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if self.left.shape != self.right.shape:
@@ -189,6 +190,7 @@ class Add(Expr, HasRepr):
class Sub(Expr, ReducibleExpr):
""" Generic subtraction operator (no type check) """
+ @property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if self.left.shape != self.right.shape:
@@ -207,6 +209,7 @@ class Sub(Expr, ReducibleExpr):
class Mul(Expr, HasRepr):
""" Generic multiplication operator (no type check). """
+ @property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if self.left.shape != self.right.shape:
@@ -238,6 +241,7 @@ class Mul(Expr, HasRepr):
class Exp(Expr, ReducibleExpr):
""" Generic exponentiation (no type check). """
+ @property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if self.left.shape != self.right.shape:
@@ -349,6 +353,7 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
def __init__(self, with_respect_to: Var):
self.wrt = with_respect_to
+ @property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
return self.inner.shape
@@ -442,16 +447,16 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
def transpose(self):
""" Matrix transposition. """
- raise NotImplementedError
+ return MatTranspose(self)
@property
def T(self):
""" Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """
return self.transpose()
- def to_scalar(self, scalar_type: type):
- """ Convert to a scalar expression. """
- raise NotImplementedError
+ # def to_scalar(self, scalar_type: type):
+ # """ Convert to a scalar expression. """
+ # raise NotImplementedError
@binary_operator(MatrixAlgebra, MatrixAlgebra)
@@ -466,7 +471,7 @@ class MatSub(Sub, MatrixAlgebra):
@binary_operator(MatrixAlgebra, MatrixAlgebra)
class MatElemMul(Mul, MatrixAlgebra):
- """ Elementwise Matrix Multiplication """
+ """ Elementwise Matrix Multiplication. """
def __repr__(self) -> str:
return f"({self.left} .* {self.right})"
@@ -475,10 +480,38 @@ class MatElemMul(Mul, MatrixAlgebra):
@binary_operator(MatrixAlgebra, MatrixAlgebra)
class MatScalarMul(ReducibleExpr, MatrixAlgebra):
""" Matrix-Scalar Multiplication. """
- def reduce(self) -> Expr:
- raise NotImplementedError
+
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if self.right.shape == Shape.scalar:
+ return self.left.shape
+
+ elif self.left.shape == Shape.scalar:
+ return self.right.shape
+
+ raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.")
+
+@binary_operator(MatrixAlgebra, MatrixAlgebra)
+class MatInnerProd(Expr, MatrixAlgebra):
+ """ Inner product. """
@binary_operator(MatrixAlgebra, MatrixAlgebra)
class MatMul(Expr, MatrixAlgebra):
""" Matrix Multiplication. """
+
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ ...
+
+ def __repr__(self) -> str:
+ return f"({self.left} @ {self.right})"
+
+
+@unary_operator(MatrixAlgebra)
+class MatTranspose(HasRepr, MatrixAlgebra):
+ """ Matrix transposition """
+
+ # def to_repr(self,