diff options
-rw-r--r-- | mdpoly/algebra.py | 64 |
1 files changed, 42 insertions, 22 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 6f046b7..d44a528 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -1,6 +1,6 @@ """ Algebraic Structures for Expressions """ -from .abc import Expr, Repr +from .abc import Leaf, Expr, Repr from .leaves import Nothing, Const, Param, Var from .errors import AlgebraicError from .state import State @@ -16,18 +16,6 @@ from abc import abstractmethod import operator -def binary_operator(cls: Expr) -> Expr: - """ Class decorator to modify constructor of binary operators. """ - __init_cls__ = cls.__init__ - @wraps(cls.__init__) - def __init_binary__(self, left, right, *args, **kwargs): - __init_cls__(self, *args, **kwargs) - self.left, self.right = left, right - - cls.__init__ = __init_binary__ - return cls - - def unary_operator(cls: Expr) -> Expr: """ Class decorator to modify constructor of unary operators. """ __init_cls__ = cls.__init__ @@ -119,13 +107,45 @@ class ReducibleExpr(HasRepr, Protocol): """ Reduce the expression to its basic elements """ +def binary_operator(left_type: AlgebraicStructure, right_type: AlgebraicStructure): + """ Class decorator that adds constructor for binary operations of Expr. + + Classes that inherit :py:class:`mdpoly.abc.Expr` take values for *left* and + *right* to represent the operands. This binary operator specifies the + algebra (:py:class:`mdpoly.algebra.AlgebraicStructure`) that *left* and + *right* have to respect in order for the operation to be correct. + Concretely, this is to raise and exception for eg. when a scalar is added + to a matrix, etc. + """ + # TODO: add right_shape and left_shape for matrices + def decorator(cls: Expr) -> Expr: + init_cls = cls.__init__ + @wraps(cls.__init__) + def new_init_cls(self, left, right, *args, **kwargs): + init_cls(self, *args, **kwargs) + # Wrong algebra + if not isinstance(left, left_type) or not isinstance(right, right_type): + # None of the two is a Leaf. This is a workaround because + # adding the algebra to the base Leaf types Const, Var, etc. is + # not possible without having circular imports. For the exported types + # Constant, Variable, etc. this is not a problem. + if not isinstance(left, Leaf) and not isinstance(right, Leaf): + raise AlgebraicError( + "Cannot perform operation between types from " + f"different algebraic structures {type(left)} ({left._algebra}) " + f"and {type(right)} ({right._algebra})") + + self.left, self.right = left, right + cls.__init__ = new_init_cls + return cls + return decorator + # ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ # ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ # ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ class PolyRingAlgebra(AlgebraicStructure, Protocol): """ Endows with the algebraic structure of a polynomial ring. """ - _algebra = AlgebraicStructure.Algebra.poly_ring _parameter = Param _constant = Const @@ -172,7 +192,7 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): return Exponentiate(self, other) -@binary_operator +@binary_operator(PolyRingAlgebra, PolyRingAlgebra) class Add(Expr, HasRepr, PolyRingAlgebra): """ Addition operator between scalar polynomials. """ @@ -197,7 +217,7 @@ class Add(Expr, HasRepr, PolyRingAlgebra): return f"({self.left} + {self.right})" -@binary_operator +@binary_operator(PolyRingAlgebra, PolyRingAlgebra) class Subtract(Expr, ReducibleExpr, PolyRingAlgebra): """ Subtraction operator between scalar polynomials. """ @@ -209,7 +229,7 @@ class Subtract(Expr, ReducibleExpr, PolyRingAlgebra): return f"({self.left} - {self.right})" -@binary_operator +@binary_operator(PolyRingAlgebra, PolyRingAlgebra) class Multiply(Expr, HasRepr, PolyRingAlgebra): """ Multiplication operator between scalar polynomials. """ @@ -235,7 +255,7 @@ class Multiply(Expr, HasRepr, PolyRingAlgebra): return f"({self.left} * {self.right})" -@binary_operator +@binary_operator(PolyRingAlgebra, PolyRingAlgebra) class Divide(Expr, ReducibleExpr, PolyRingAlgebra): """ Division operator between scalar polynomials. """ @@ -248,7 +268,7 @@ class Divide(Expr, ReducibleExpr, PolyRingAlgebra): return f"({self.left} / {self.right})" -@binary_operator +@binary_operator(PolyRingAlgebra, PolyRingAlgebra) class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra): """ Exponentiation operator of scalar polynomials. """ @@ -344,16 +364,16 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): raise NotImplementedError -@binary_operator +@binary_operator(MatrixAlgebra, MatrixAlgebra) class MatAdd(Expr, MatrixAlgebra): """ Matrix Addition. """ -@binary_operator +@binary_operator(MatrixAlgebra, MatrixAlgebra) class ScalarMul(Expr, MatrixAlgebra): """ Matrix-Scalar Multiplication. """ -@binary_operator +@binary_operator(MatrixAlgebra, MatrixAlgebra) class MatMul(Expr, MatrixAlgebra): """ Matrix Multiplication. """ |