aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/algebra.py121
-rw-r--r--mdpoly/representations.py10
2 files changed, 115 insertions, 16 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 8a8e961..4d8e47e 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -2,19 +2,24 @@
Algebraic Structures for Expressions
"""
-from .abc import Expr
+from .abc import Expr, Repr
from .leaves import Nothing, Const, Param
from .errors import AlgebraicError
-# from .state import State
-# from .representations import HasSparseRepr, SparseRepr
+from .state import State
+from .types import MatrixIndex, PolyIndex
+from .representations import HasRepr
from typing import Protocol, TypeVar, Any, runtime_checkable
-from functools import wraps
+from functools import wraps, reduce
+from itertools import product
from enum import Enum
from abc import abstractmethod
+import operator
+
def binary_operator(cls: Expr) -> Expr:
+ """ Class decorator to modify constructor of binary operators """
__init_cls__ = cls.__init__
@wraps(cls.__init__)
def __init_binary__(self, left, right, *args, **kwargs):
@@ -26,6 +31,7 @@ def binary_operator(cls: Expr) -> Expr:
def unary_operator(cls: Expr) -> Expr:
+ """ Class decorator to modify constructor of unary operators """
__init_cls__ = cls.__init__
@wraps(cls.__init__)
def __init_unary__(self, left, *args, **kwargs):
@@ -41,7 +47,7 @@ T = TypeVar("T")
@runtime_checkable
class AlgebraicStructure(Protocol):
"""
- Provides method to enforce algebraic closure of operations
+ Provides methods to enforce algebraic closure of operations
"""
class Algebra(Enum):
""" Types of algebras """
@@ -92,7 +98,17 @@ class AlgebraicStructure(Protocol):
-class ReducibleExpr(Protocol):
+class ReducibleExpr(HasRepr, Protocol):
+ """
+ Algebraic expression that can be written in terms of other (existing)
+ expression objects, i.e. can be reduced to another expression made of
+ simpler operations. For example subtraction can be written in term of
+ addition and multiplication.
+ """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ return self.reduce().to_repr(repr_type, state)
+
@abstractmethod
def reduce(self) -> Expr: ...
@@ -113,10 +129,18 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
other = self._wrap_if_constant(other)
return Add(self, other)
+ def __radd__(self, other):
+ other = self._wrap_if_constant(other)
+ return Add(other, self)
+
def __sub__(self, other):
other = self._wrap_if_constant(other)
return Subtract(self, other)
+ def __rsub__(self, other):
+ other = self._wrap_if_constant(other)
+ return Subtract(other, self)
+
def __neg__(self, other):
return Multiply(self._constant(-1), other)
@@ -132,7 +156,8 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
other = self._wrap_if_constant(other)
if not self._is_const_or_param(other):
raise AlgebraicError("Cannot divide by variables in polynomial ring.")
- return Multiply(Const(value=1/other.value), self)
+
+ return Divide(self, other)
def __pow__(self, other):
other = self._wrap_if_constant(other)
@@ -143,25 +168,88 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
@binary_operator
-class Add(Expr, PolyRingAlgebra):
- """ Sum operator """
+class Add(Expr, HasRepr, PolyRingAlgebra):
+ """ Sum operator between scalar polynomials """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ # Make a new empty representation
+ r = repr_type()
+ entry = MatrixIndex.scalar
+
+ # Make representations for existing stuff
+ for node in self.nodes():
+ nrepr, state = node.to_repr(repr_type, state)
+
+ # Process non-zero terms
+ for term in nrepr.terms(entry):
+ s = r.at(entry, term) + nrepr.at(entry, term)
+ r.set(entry, term, s)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} + {self.right})"
@binary_operator
class Subtract(Expr, ReducibleExpr, PolyRingAlgebra):
- """ Subtraction operation """
+ """ Subtraction operation between scalar polynomials """
+
def reduce(self) -> Expr:
return Add(self.left, Multiply(self._constant(value=-1), self.right))
+ def __repr__(self) -> str:
+ return f"({self.left} - {self.right})"
+
@binary_operator
-class Multiply(Expr, PolyRingAlgebra):
- """ Multiplication operator """
+class Multiply(Expr, HasRepr, PolyRingAlgebra):
+ """ Multiplication operator between scalar polynomials """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ r = repr_type()
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ entry = MatrixIndex.scalar
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
+
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} * {self.right})"
@binary_operator
-class Exponentiate(Expr, PolyRingAlgebra):
- """ Exponentiation operator """
+class Divide(Expr, ReducibleExpr, PolyRingAlgebra):
+ """ Division operator between scalar polynomials """
+
+ def reduce(self) -> Expr:
+ inverse = self._constant(value=1/self.right.value)
+ return Multiply(inverse, self.left)
+
+ def __repr__(self) -> str:
+ return f"({self.left} / {self.right})"
+
+
+@binary_operator
+class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra):
+ """ Exponentiation operator of scalar polynomials """
+
+ def reduce(self) -> Expr:
+ var = self.left
+ ntimes = self.right.value - 1
+ return reduce(operator.mul, (var for _ in range(ntimes)), var)
+
+ def __repr__(self) -> str:
+ return f"({self.left} ** {self.right})"
@unary_operator
@@ -213,12 +301,13 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
def transpose(self):
raise NotImplementedError
- # Shorthands & syntactic sugar
-
@property
def T(self):
return self.transpose()
+ def as_scalar(self) -> PolyRingAlgebra:
+ raise NotImplementedError
+
@binary_operator
class MatAdd(Expr, MatrixAlgebra):
diff --git a/mdpoly/representations.py b/mdpoly/representations.py
index 55763a9..5e58cb4 100644
--- a/mdpoly/representations.py
+++ b/mdpoly/representations.py
@@ -1,5 +1,6 @@
from .abc import Repr
from .types import Number, Shape, MatrixIndex, PolyIndex
+from .state import State
from typing import Protocol, Sequence
from abc import abstractmethod
@@ -8,6 +9,11 @@ import numpy as np
import numpy.typing as npt
+class HasRepr(Protocol):
+ @abstractmethod
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: ...
+
+
class SparseRepr(Repr):
"""
Sparse representation of polynomial
@@ -66,16 +72,20 @@ class SparseMatrixRepr(Repr):
def at(self, entry: MatrixIndex, term: PolyIndex) -> Number:
""" Access polynomial entry """
+ raise NotImplementedError
def set(self, entry: MatrixIndex, term: PolyIndex, value: Number) -> None:
""" Set value of polynomial entry """
+ raise NotImplementedError
def entries(self) -> Sequence[MatrixIndex]:
""" Return indices to non-zero entries of the matrix """
+ raise NotImplementedError
def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]:
""" Return indices to non-zero terms in the polynomial at the given
matrix entry """
+ raise NotImplementedError
def basis(self):
raise NotImplementedError