aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--mdpoly/algebra.py218
1 files changed, 131 insertions, 87 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 78e4965..b76a849 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -4,7 +4,7 @@ from .abc import Leaf, Expr, Repr
from .leaves import Nothing, Const, Param, Var
from .errors import AlgebraicError
from .state import State
-from .types import MatrixIndex, PolyIndex
+from .types import Shape, MatrixIndex, PolyIndex
from .representations import HasRepr
from typing import Protocol, TypeVar, Any, runtime_checkable
@@ -36,6 +36,10 @@ class AlgebraicStructure(Protocol):
_parameter: type
_constant: type
+ @property
+ @abstractmethod
+ def shape(self) -> Shape: ...
+
@classmethod
def _is_constant(cls: T, other: T) -> bool:
return isinstance(other, cls._constant)
@@ -75,7 +79,6 @@ class AlgebraicStructure(Protocol):
# f"objects with different shapes {cls.shape} and {other.shape}")
-
class ReducibleExpr(HasRepr, Protocol):
""" Reducible Expression
@@ -156,6 +159,81 @@ def unary_operator(inner_type: AlgebraicStructure):
return decorator
+class Add(Expr, HasRepr):
+ """ Generic addition (no type check) """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
+ # Make a new empty representation
+ r = repr_type()
+ entry = MatrixIndex.scalar
+
+ # Make representations for existing stuff
+ for node in self.children():
+ nrepr, state = node.to_repr(repr_type, state)
+
+ # Process non-zero terms
+ for entry in nrepr.entries():
+ for term in nrepr.terms(entry):
+ s = r.at(entry, term) + nrepr.at(entry, term)
+ r.set(entry, term, s)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} + {self.right})"
+
+
+class Sub(Expr, ReducibleExpr):
+ """ Generic subtraction operator (no type check) """
+
+ def reduce(self) -> Expr:
+ """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
+ return self.left + (-1 * self.right)
+ # return Add(self.left, Mul(self._constant(value=-1), self.right))
+
+ def __repr__(self) -> str:
+ return f"({self.left} - {self.right})"
+
+
+class Mul(Expr, HasRepr):
+ """ Generic multiplication operator (no type check). """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
+ r = repr_type()
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ entry = MatrixIndex.scalar
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
+
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} * {self.right})"
+
+
+class Exp(Expr, ReducibleExpr):
+ """ Generic exponentiation (no type check). """
+
+ def reduce(self) -> Expr:
+ """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
+ var = self.left
+ ntimes = self.right.value - 1
+ return reduce(operator.mul, (var for _ in range(ntimes)), var)
+
+ def __repr__(self) -> str:
+ return f"({self.left} ** {self.right})"
+
+
# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓
# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
@@ -166,139 +244,86 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
_parameter = Param
_constant = Const
+ @property
+ def shape(self):
+ return Shape.scalar
+
def __add__(self, other):
other = self._wrap_if_constant(other)
- return Add(self, other)
+ return PolyAdd(self, other)
def __radd__(self, other):
other = self._wrap_if_constant(other)
- return Add(other, self)
+ return PolyAdd(other, self)
def __sub__(self, other):
other = self._wrap_if_constant(other)
- return Subtract(self, other)
+ return PolySub(self, other)
def __rsub__(self, other):
other = self._wrap_if_constant(other)
- return Subtract(other, self)
+ return PolyAdd(other, self)
def __neg__(self, other):
- return Multiply(self._constant(-1), other)
+ return PolyMul(self._constant(-1), other)
def __mul__(self, other):
other = self._wrap_if_constant(other)
- return Multiply(self, other)
+ return PolyMul(self, other)
def __rmul__(self, other):
other = self._wrap_if_constant(other)
- return Multiply(other, self)
+ return PolyMul(other, self)
def __truediv__(self, other):
other = self._wrap_if_constant(other)
if not self._is_const_or_param(other):
raise AlgebraicError("Cannot divide by variables in polynomial ring.")
- return Divide(self, other)
+ return PolyMul(self, other)
def __pow__(self, other):
other = self._wrap_if_constant(other)
if not self._is_const_or_param(other):
raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in "
"polynomial ring. Only constants and parameters are allowed.")
- return Exponentiate(self, other)
+ return PolyExp(self, other)
@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class Add(Expr, HasRepr, PolyRingAlgebra):
+class PolyAdd(Add, PolyRingAlgebra):
""" Addition operator between scalar polynomials. """
- def to_repr(self, repr_type: type, state: State) -> Repr:
- """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
- # Make a new empty representation
- r = repr_type()
- entry = MatrixIndex.scalar
-
- # Make representations for existing stuff
- for node in self.children():
- nrepr, state = node.to_repr(repr_type, state)
-
- # Process non-zero terms
- for term in nrepr.terms(entry):
- s = r.at(entry, term) + nrepr.at(entry, term)
- r.set(entry, term, s)
-
- return r, state
-
- def __repr__(self) -> str:
- return f"({self.left} + {self.right})"
-
@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class Subtract(Expr, ReducibleExpr, PolyRingAlgebra):
+class PolySub(Sub, PolyRingAlgebra):
""" Subtraction operator between scalar polynomials. """
- def reduce(self) -> Expr:
- """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
- return Add(self.left, Multiply(self._constant(value=-1), self.right))
-
- def __repr__(self) -> str:
- return f"({self.left} - {self.right})"
-
@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class Multiply(Expr, HasRepr, PolyRingAlgebra):
+class PolyMul(Mul, PolyRingAlgebra):
""" Multiplication operator between scalar polynomials. """
- def to_repr(self, repr_type: type, state: State) -> Repr:
- """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
- r = repr_type()
-
- lrepr, state = self.left.to_repr(repr_type, state)
- rrepr, state = self.right.to_repr(repr_type, state)
-
- entry = MatrixIndex.scalar
- for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
- # Compute where the results should go
- term = PolyIndex.product(lterm, rterm)
-
- # Compute product
- p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
- r.set(entry, term, p)
-
- return r, state
-
- def __repr__(self) -> str:
- return f"({self.left} * {self.right})"
-
@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class Divide(Expr, ReducibleExpr, PolyRingAlgebra):
- """ Division operator between scalar polynomials. """
+class PolyDiv(Expr, ReducibleExpr, PolyRingAlgebra):
+ """ Division of scalar polynomial by scalar. """
def reduce(self) -> Expr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
inverse = self._constant(value=1/self.right.value)
- return Multiply(inverse, self.left)
+ return PolyMul(inverse, self.left)
def __repr__(self) -> str:
return f"({self.left} / {self.right})"
@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra):
+class PolyExp(Exp, PolyRingAlgebra):
""" Exponentiation operator of scalar polynomials. """
- def reduce(self) -> Expr:
- """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
- var = self.left
- ntimes = self.right.value - 1
- return reduce(operator.mul, (var for _ in range(ntimes)), var)
-
- def __repr__(self) -> str:
- return f"({self.left} ** {self.right})"
-
-@unary_operator
+@unary_operator(PolyRingAlgebra)
class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
""" Partial differentiation of scalar polynomials """
def __init__(self, with_respect_to: Var):
@@ -342,17 +367,21 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
other = self._wrap_if_constant(other)
return MatAdd(self, other)
- def __sub___(self, other):
+ def __sub__(self, other):
other = self._wrap_if_constant(other)
- return MatAdd(self, ScalarMul(Const(value=-1), other))
+ return MatSub(self, other)
- def __mul__(self, scalar):
- scalar = self._wrap_if_constant(scalar)
- return ScalarMul(scalar, self)
+ def __rsub__(self, other):
+ other = self._wrap_if_constant(other)
+ return MatSub(other, self)
- def __rmul__(self, scalar):
- scalar = self._wrap_if_constant(scalar)
- return ScalarMul(scalar, self)
+ def __mul__(self, other):
+ other = self._wrap_if_constant(other)
+ return MatScalarMul(other, self)
+
+ def __rmul__(self, other):
+ scalar = self._wrap_if_constant(other)
+ return MatScalarMul(scalar, self)
def __matmul__(self, other):
other = self._wrap_if_constant(other)
@@ -375,19 +404,34 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
""" Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """
return self.transpose()
- def as_scalar(self) -> PolyRingAlgebra:
+ def to_scalar(self, scalar_type: type):
""" Convert to a scalar expression. """
raise NotImplementedError
@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatAdd(Expr, MatrixAlgebra):
+class MatAdd(Add, MatrixAlgebra):
""" Matrix Addition. """
@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class ScalarMul(Expr, MatrixAlgebra):
+class MatSub(Sub, MatrixAlgebra):
+ """ Matrix Subtraction. """
+
+
+@binary_operator(MatrixAlgebra, MatrixAlgebra)
+class MatElemMul(Mul, MatrixAlgebra):
+ """ Elementwise Matrix Multiplication """
+
+ def __repr__(self) -> str:
+ return f"({self.left} .* {self.right})"
+
+
+@binary_operator(MatrixAlgebra, MatrixAlgebra)
+class MatScalarMul(ReducibleExpr, MatrixAlgebra):
""" Matrix-Scalar Multiplication. """
+ def reduce(self) -> Expr:
+ raise NotImplementedError
@binary_operator(MatrixAlgebra, MatrixAlgebra)