diff options
Diffstat (limited to '')
-rw-r--r-- | mdpoly/algebra.py | 121 | ||||
-rw-r--r-- | mdpoly/representations.py | 10 |
2 files changed, 115 insertions, 16 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 8a8e961..4d8e47e 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -2,19 +2,24 @@ Algebraic Structures for Expressions """ -from .abc import Expr +from .abc import Expr, Repr from .leaves import Nothing, Const, Param from .errors import AlgebraicError -# from .state import State -# from .representations import HasSparseRepr, SparseRepr +from .state import State +from .types import MatrixIndex, PolyIndex +from .representations import HasRepr from typing import Protocol, TypeVar, Any, runtime_checkable -from functools import wraps +from functools import wraps, reduce +from itertools import product from enum import Enum from abc import abstractmethod +import operator + def binary_operator(cls: Expr) -> Expr: + """ Class decorator to modify constructor of binary operators """ __init_cls__ = cls.__init__ @wraps(cls.__init__) def __init_binary__(self, left, right, *args, **kwargs): @@ -26,6 +31,7 @@ def binary_operator(cls: Expr) -> Expr: def unary_operator(cls: Expr) -> Expr: + """ Class decorator to modify constructor of unary operators """ __init_cls__ = cls.__init__ @wraps(cls.__init__) def __init_unary__(self, left, *args, **kwargs): @@ -41,7 +47,7 @@ T = TypeVar("T") @runtime_checkable class AlgebraicStructure(Protocol): """ - Provides method to enforce algebraic closure of operations + Provides methods to enforce algebraic closure of operations """ class Algebra(Enum): """ Types of algebras """ @@ -92,7 +98,17 @@ class AlgebraicStructure(Protocol): -class ReducibleExpr(Protocol): +class ReducibleExpr(HasRepr, Protocol): + """ + Algebraic expression that can be written in terms of other (existing) + expression objects, i.e. can be reduced to another expression made of + simpler operations. For example subtraction can be written in term of + addition and multiplication. + """ + + def to_repr(self, repr_type: type, state: State) -> Repr: + return self.reduce().to_repr(repr_type, state) + @abstractmethod def reduce(self) -> Expr: ... @@ -113,10 +129,18 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): other = self._wrap_if_constant(other) return Add(self, other) + def __radd__(self, other): + other = self._wrap_if_constant(other) + return Add(other, self) + def __sub__(self, other): other = self._wrap_if_constant(other) return Subtract(self, other) + def __rsub__(self, other): + other = self._wrap_if_constant(other) + return Subtract(other, self) + def __neg__(self, other): return Multiply(self._constant(-1), other) @@ -132,7 +156,8 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): other = self._wrap_if_constant(other) if not self._is_const_or_param(other): raise AlgebraicError("Cannot divide by variables in polynomial ring.") - return Multiply(Const(value=1/other.value), self) + + return Divide(self, other) def __pow__(self, other): other = self._wrap_if_constant(other) @@ -143,25 +168,88 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): @binary_operator -class Add(Expr, PolyRingAlgebra): - """ Sum operator """ +class Add(Expr, HasRepr, PolyRingAlgebra): + """ Sum operator between scalar polynomials """ + + def to_repr(self, repr_type: type, state: State) -> Repr: + # Make a new empty representation + r = repr_type() + entry = MatrixIndex.scalar + + # Make representations for existing stuff + for node in self.nodes(): + nrepr, state = node.to_repr(repr_type, state) + + # Process non-zero terms + for term in nrepr.terms(entry): + s = r.at(entry, term) + nrepr.at(entry, term) + r.set(entry, term, s) + + return r, state + + def __repr__(self) -> str: + return f"({self.left} + {self.right})" @binary_operator class Subtract(Expr, ReducibleExpr, PolyRingAlgebra): - """ Subtraction operation """ + """ Subtraction operation between scalar polynomials """ + def reduce(self) -> Expr: return Add(self.left, Multiply(self._constant(value=-1), self.right)) + def __repr__(self) -> str: + return f"({self.left} - {self.right})" + @binary_operator -class Multiply(Expr, PolyRingAlgebra): - """ Multiplication operator """ +class Multiply(Expr, HasRepr, PolyRingAlgebra): + """ Multiplication operator between scalar polynomials """ + + def to_repr(self, repr_type: type, state: State) -> Repr: + r = repr_type() + + lrepr, state = self.left.to_repr(repr_type, state) + rrepr, state = self.right.to_repr(repr_type, state) + + entry = MatrixIndex.scalar + for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): + # Compute where the results should go + term = PolyIndex.product(lterm, rterm) + + # Compute product + p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) + r.set(entry, term, p) + + return r, state + + def __repr__(self) -> str: + return f"({self.left} * {self.right})" @binary_operator -class Exponentiate(Expr, PolyRingAlgebra): - """ Exponentiation operator """ +class Divide(Expr, ReducibleExpr, PolyRingAlgebra): + """ Division operator between scalar polynomials """ + + def reduce(self) -> Expr: + inverse = self._constant(value=1/self.right.value) + return Multiply(inverse, self.left) + + def __repr__(self) -> str: + return f"({self.left} / {self.right})" + + +@binary_operator +class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra): + """ Exponentiation operator of scalar polynomials """ + + def reduce(self) -> Expr: + var = self.left + ntimes = self.right.value - 1 + return reduce(operator.mul, (var for _ in range(ntimes)), var) + + def __repr__(self) -> str: + return f"({self.left} ** {self.right})" @unary_operator @@ -213,12 +301,13 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): def transpose(self): raise NotImplementedError - # Shorthands & syntactic sugar - @property def T(self): return self.transpose() + def as_scalar(self) -> PolyRingAlgebra: + raise NotImplementedError + @binary_operator class MatAdd(Expr, MatrixAlgebra): diff --git a/mdpoly/representations.py b/mdpoly/representations.py index 55763a9..5e58cb4 100644 --- a/mdpoly/representations.py +++ b/mdpoly/representations.py @@ -1,5 +1,6 @@ from .abc import Repr from .types import Number, Shape, MatrixIndex, PolyIndex +from .state import State from typing import Protocol, Sequence from abc import abstractmethod @@ -8,6 +9,11 @@ import numpy as np import numpy.typing as npt +class HasRepr(Protocol): + @abstractmethod + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: ... + + class SparseRepr(Repr): """ Sparse representation of polynomial @@ -66,16 +72,20 @@ class SparseMatrixRepr(Repr): def at(self, entry: MatrixIndex, term: PolyIndex) -> Number: """ Access polynomial entry """ + raise NotImplementedError def set(self, entry: MatrixIndex, term: PolyIndex, value: Number) -> None: """ Set value of polynomial entry """ + raise NotImplementedError def entries(self) -> Sequence[MatrixIndex]: """ Return indices to non-zero entries of the matrix """ + raise NotImplementedError def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]: """ Return indices to non-zero terms in the polynomial at the given matrix entry """ + raise NotImplementedError def basis(self): raise NotImplementedError |