diff options
-rw-r--r-- | mdpoly/algebra.py | 218 |
1 files changed, 131 insertions, 87 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 78e4965..b76a849 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -4,7 +4,7 @@ from .abc import Leaf, Expr, Repr from .leaves import Nothing, Const, Param, Var from .errors import AlgebraicError from .state import State -from .types import MatrixIndex, PolyIndex +from .types import Shape, MatrixIndex, PolyIndex from .representations import HasRepr from typing import Protocol, TypeVar, Any, runtime_checkable @@ -36,6 +36,10 @@ class AlgebraicStructure(Protocol): _parameter: type _constant: type + @property + @abstractmethod + def shape(self) -> Shape: ... + @classmethod def _is_constant(cls: T, other: T) -> bool: return isinstance(other, cls._constant) @@ -75,7 +79,6 @@ class AlgebraicStructure(Protocol): # f"objects with different shapes {cls.shape} and {other.shape}") - class ReducibleExpr(HasRepr, Protocol): """ Reducible Expression @@ -156,6 +159,81 @@ def unary_operator(inner_type: AlgebraicStructure): return decorator +class Add(Expr, HasRepr): + """ Generic addition (no type check) """ + + def to_repr(self, repr_type: type, state: State) -> Repr: + """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ + # Make a new empty representation + r = repr_type() + entry = MatrixIndex.scalar + + # Make representations for existing stuff + for node in self.children(): + nrepr, state = node.to_repr(repr_type, state) + + # Process non-zero terms + for entry in nrepr.entries(): + for term in nrepr.terms(entry): + s = r.at(entry, term) + nrepr.at(entry, term) + r.set(entry, term, s) + + return r, state + + def __repr__(self) -> str: + return f"({self.left} + {self.right})" + + +class Sub(Expr, ReducibleExpr): + """ Generic subtraction operator (no type check) """ + + def reduce(self) -> Expr: + """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ + return self.left + (-1 * self.right) + # return Add(self.left, Mul(self._constant(value=-1), self.right)) + + def __repr__(self) -> str: + return f"({self.left} - {self.right})" + + +class Mul(Expr, HasRepr): + """ Generic multiplication operator (no type check). """ + + def to_repr(self, repr_type: type, state: State) -> Repr: + """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ + r = repr_type() + + lrepr, state = self.left.to_repr(repr_type, state) + rrepr, state = self.right.to_repr(repr_type, state) + + entry = MatrixIndex.scalar + for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): + # Compute where the results should go + term = PolyIndex.product(lterm, rterm) + + # Compute product + p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) + r.set(entry, term, p) + + return r, state + + def __repr__(self) -> str: + return f"({self.left} * {self.right})" + + +class Exp(Expr, ReducibleExpr): + """ Generic exponentiation (no type check). """ + + def reduce(self) -> Expr: + """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ + var = self.left + ntimes = self.right.value - 1 + return reduce(operator.mul, (var for _ in range(ntimes)), var) + + def __repr__(self) -> str: + return f"({self.left} ** {self.right})" + + # ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ # ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ # ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ @@ -166,139 +244,86 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): _parameter = Param _constant = Const + @property + def shape(self): + return Shape.scalar + def __add__(self, other): other = self._wrap_if_constant(other) - return Add(self, other) + return PolyAdd(self, other) def __radd__(self, other): other = self._wrap_if_constant(other) - return Add(other, self) + return PolyAdd(other, self) def __sub__(self, other): other = self._wrap_if_constant(other) - return Subtract(self, other) + return PolySub(self, other) def __rsub__(self, other): other = self._wrap_if_constant(other) - return Subtract(other, self) + return PolyAdd(other, self) def __neg__(self, other): - return Multiply(self._constant(-1), other) + return PolyMul(self._constant(-1), other) def __mul__(self, other): other = self._wrap_if_constant(other) - return Multiply(self, other) + return PolyMul(self, other) def __rmul__(self, other): other = self._wrap_if_constant(other) - return Multiply(other, self) + return PolyMul(other, self) def __truediv__(self, other): other = self._wrap_if_constant(other) if not self._is_const_or_param(other): raise AlgebraicError("Cannot divide by variables in polynomial ring.") - return Divide(self, other) + return PolyMul(self, other) def __pow__(self, other): other = self._wrap_if_constant(other) if not self._is_const_or_param(other): raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in " "polynomial ring. Only constants and parameters are allowed.") - return Exponentiate(self, other) + return PolyExp(self, other) @binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class Add(Expr, HasRepr, PolyRingAlgebra): +class PolyAdd(Add, PolyRingAlgebra): """ Addition operator between scalar polynomials. """ - def to_repr(self, repr_type: type, state: State) -> Repr: - """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ - # Make a new empty representation - r = repr_type() - entry = MatrixIndex.scalar - - # Make representations for existing stuff - for node in self.children(): - nrepr, state = node.to_repr(repr_type, state) - - # Process non-zero terms - for term in nrepr.terms(entry): - s = r.at(entry, term) + nrepr.at(entry, term) - r.set(entry, term, s) - - return r, state - - def __repr__(self) -> str: - return f"({self.left} + {self.right})" - @binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class Subtract(Expr, ReducibleExpr, PolyRingAlgebra): +class PolySub(Sub, PolyRingAlgebra): """ Subtraction operator between scalar polynomials. """ - def reduce(self) -> Expr: - """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ - return Add(self.left, Multiply(self._constant(value=-1), self.right)) - - def __repr__(self) -> str: - return f"({self.left} - {self.right})" - @binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class Multiply(Expr, HasRepr, PolyRingAlgebra): +class PolyMul(Mul, PolyRingAlgebra): """ Multiplication operator between scalar polynomials. """ - def to_repr(self, repr_type: type, state: State) -> Repr: - """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ - r = repr_type() - - lrepr, state = self.left.to_repr(repr_type, state) - rrepr, state = self.right.to_repr(repr_type, state) - - entry = MatrixIndex.scalar - for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): - # Compute where the results should go - term = PolyIndex.product(lterm, rterm) - - # Compute product - p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) - r.set(entry, term, p) - - return r, state - - def __repr__(self) -> str: - return f"({self.left} * {self.right})" - @binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class Divide(Expr, ReducibleExpr, PolyRingAlgebra): - """ Division operator between scalar polynomials. """ +class PolyDiv(Expr, ReducibleExpr, PolyRingAlgebra): + """ Division of scalar polynomial by scalar. """ def reduce(self) -> Expr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ inverse = self._constant(value=1/self.right.value) - return Multiply(inverse, self.left) + return PolyMul(inverse, self.left) def __repr__(self) -> str: return f"({self.left} / {self.right})" @binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra): +class PolyExp(Exp, PolyRingAlgebra): """ Exponentiation operator of scalar polynomials. """ - def reduce(self) -> Expr: - """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ - var = self.left - ntimes = self.right.value - 1 - return reduce(operator.mul, (var for _ in range(ntimes)), var) - - def __repr__(self) -> str: - return f"({self.left} ** {self.right})" - -@unary_operator +@unary_operator(PolyRingAlgebra) class PartialDiff(Expr, HasRepr, PolyRingAlgebra): """ Partial differentiation of scalar polynomials """ def __init__(self, with_respect_to: Var): @@ -342,17 +367,21 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): other = self._wrap_if_constant(other) return MatAdd(self, other) - def __sub___(self, other): + def __sub__(self, other): other = self._wrap_if_constant(other) - return MatAdd(self, ScalarMul(Const(value=-1), other)) + return MatSub(self, other) - def __mul__(self, scalar): - scalar = self._wrap_if_constant(scalar) - return ScalarMul(scalar, self) + def __rsub__(self, other): + other = self._wrap_if_constant(other) + return MatSub(other, self) - def __rmul__(self, scalar): - scalar = self._wrap_if_constant(scalar) - return ScalarMul(scalar, self) + def __mul__(self, other): + other = self._wrap_if_constant(other) + return MatScalarMul(other, self) + + def __rmul__(self, other): + scalar = self._wrap_if_constant(other) + return MatScalarMul(scalar, self) def __matmul__(self, other): other = self._wrap_if_constant(other) @@ -375,19 +404,34 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): """ Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """ return self.transpose() - def as_scalar(self) -> PolyRingAlgebra: + def to_scalar(self, scalar_type: type): """ Convert to a scalar expression. """ raise NotImplementedError @binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatAdd(Expr, MatrixAlgebra): +class MatAdd(Add, MatrixAlgebra): """ Matrix Addition. """ @binary_operator(MatrixAlgebra, MatrixAlgebra) -class ScalarMul(Expr, MatrixAlgebra): +class MatSub(Sub, MatrixAlgebra): + """ Matrix Subtraction. """ + + +@binary_operator(MatrixAlgebra, MatrixAlgebra) +class MatElemMul(Mul, MatrixAlgebra): + """ Elementwise Matrix Multiplication """ + + def __repr__(self) -> str: + return f"({self.left} .* {self.right})" + + +@binary_operator(MatrixAlgebra, MatrixAlgebra) +class MatScalarMul(ReducibleExpr, MatrixAlgebra): """ Matrix-Scalar Multiplication. """ + def reduce(self) -> Expr: + raise NotImplementedError @binary_operator(MatrixAlgebra, MatrixAlgebra) |