aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/algebra.py64
1 files changed, 42 insertions, 22 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 6f046b7..d44a528 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -1,6 +1,6 @@
""" Algebraic Structures for Expressions """
-from .abc import Expr, Repr
+from .abc import Leaf, Expr, Repr
from .leaves import Nothing, Const, Param, Var
from .errors import AlgebraicError
from .state import State
@@ -16,18 +16,6 @@ from abc import abstractmethod
import operator
-def binary_operator(cls: Expr) -> Expr:
- """ Class decorator to modify constructor of binary operators. """
- __init_cls__ = cls.__init__
- @wraps(cls.__init__)
- def __init_binary__(self, left, right, *args, **kwargs):
- __init_cls__(self, *args, **kwargs)
- self.left, self.right = left, right
-
- cls.__init__ = __init_binary__
- return cls
-
-
def unary_operator(cls: Expr) -> Expr:
""" Class decorator to modify constructor of unary operators. """
__init_cls__ = cls.__init__
@@ -119,13 +107,45 @@ class ReducibleExpr(HasRepr, Protocol):
""" Reduce the expression to its basic elements """
+def binary_operator(left_type: AlgebraicStructure, right_type: AlgebraicStructure):
+ """ Class decorator that adds constructor for binary operations of Expr.
+
+ Classes that inherit :py:class:`mdpoly.abc.Expr` take values for *left* and
+ *right* to represent the operands. This binary operator specifies the
+ algebra (:py:class:`mdpoly.algebra.AlgebraicStructure`) that *left* and
+ *right* have to respect in order for the operation to be correct.
+ Concretely, this is to raise and exception for eg. when a scalar is added
+ to a matrix, etc.
+ """
+ # TODO: add right_shape and left_shape for matrices
+ def decorator(cls: Expr) -> Expr:
+ init_cls = cls.__init__
+ @wraps(cls.__init__)
+ def new_init_cls(self, left, right, *args, **kwargs):
+ init_cls(self, *args, **kwargs)
+ # Wrong algebra
+ if not isinstance(left, left_type) or not isinstance(right, right_type):
+ # None of the two is a Leaf. This is a workaround because
+ # adding the algebra to the base Leaf types Const, Var, etc. is
+ # not possible without having circular imports. For the exported types
+ # Constant, Variable, etc. this is not a problem.
+ if not isinstance(left, Leaf) and not isinstance(right, Leaf):
+ raise AlgebraicError(
+ "Cannot perform operation between types from "
+ f"different algebraic structures {type(left)} ({left._algebra}) "
+ f"and {type(right)} ({right._algebra})")
+
+ self.left, self.right = left, right
+ cls.__init__ = new_init_cls
+ return cls
+ return decorator
+
# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓
# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
class PolyRingAlgebra(AlgebraicStructure, Protocol):
""" Endows with the algebraic structure of a polynomial ring. """
-
_algebra = AlgebraicStructure.Algebra.poly_ring
_parameter = Param
_constant = Const
@@ -172,7 +192,7 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
return Exponentiate(self, other)
-@binary_operator
+@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
class Add(Expr, HasRepr, PolyRingAlgebra):
""" Addition operator between scalar polynomials. """
@@ -197,7 +217,7 @@ class Add(Expr, HasRepr, PolyRingAlgebra):
return f"({self.left} + {self.right})"
-@binary_operator
+@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
class Subtract(Expr, ReducibleExpr, PolyRingAlgebra):
""" Subtraction operator between scalar polynomials. """
@@ -209,7 +229,7 @@ class Subtract(Expr, ReducibleExpr, PolyRingAlgebra):
return f"({self.left} - {self.right})"
-@binary_operator
+@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
class Multiply(Expr, HasRepr, PolyRingAlgebra):
""" Multiplication operator between scalar polynomials. """
@@ -235,7 +255,7 @@ class Multiply(Expr, HasRepr, PolyRingAlgebra):
return f"({self.left} * {self.right})"
-@binary_operator
+@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
class Divide(Expr, ReducibleExpr, PolyRingAlgebra):
""" Division operator between scalar polynomials. """
@@ -248,7 +268,7 @@ class Divide(Expr, ReducibleExpr, PolyRingAlgebra):
return f"({self.left} / {self.right})"
-@binary_operator
+@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra):
""" Exponentiation operator of scalar polynomials. """
@@ -344,16 +364,16 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
raise NotImplementedError
-@binary_operator
+@binary_operator(MatrixAlgebra, MatrixAlgebra)
class MatAdd(Expr, MatrixAlgebra):
""" Matrix Addition. """
-@binary_operator
+@binary_operator(MatrixAlgebra, MatrixAlgebra)
class ScalarMul(Expr, MatrixAlgebra):
""" Matrix-Scalar Multiplication. """
-@binary_operator
+@binary_operator(MatrixAlgebra, MatrixAlgebra)
class MatMul(Expr, MatrixAlgebra):
""" Matrix Multiplication. """