aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-03 00:10:05 +0100
committerNao Pross <np@0hm.ch>2024-03-03 00:10:05 +0100
commitd8a9e036c16d6680e34aca9a7676d4548fb313a3 (patch)
tree588d07b8c5793dbc03263b7f313a85360ab749b8
parentAdd Repr.is_zero to check if value is small (diff)
downloadmdpoly-d8a9e036c16d6680e34aca9a7676d4548fb313a3.tar.gz
mdpoly-d8a9e036c16d6680e34aca9a7676d4548fb313a3.zip
Add representation of basic algebraic operations
Implements Add, Multiply, Subtract, Divide and Exponentiate. Some are implemented directly, others as "reducible" i.e. in term of the others.
Diffstat (limited to '')
-rw-r--r--mdpoly/algebra.py121
-rw-r--r--mdpoly/representations.py10
2 files changed, 115 insertions, 16 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 8a8e961..4d8e47e 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -2,19 +2,24 @@
Algebraic Structures for Expressions
"""
-from .abc import Expr
+from .abc import Expr, Repr
from .leaves import Nothing, Const, Param
from .errors import AlgebraicError
-# from .state import State
-# from .representations import HasSparseRepr, SparseRepr
+from .state import State
+from .types import MatrixIndex, PolyIndex
+from .representations import HasRepr
from typing import Protocol, TypeVar, Any, runtime_checkable
-from functools import wraps
+from functools import wraps, reduce
+from itertools import product
from enum import Enum
from abc import abstractmethod
+import operator
+
def binary_operator(cls: Expr) -> Expr:
+ """ Class decorator to modify constructor of binary operators """
__init_cls__ = cls.__init__
@wraps(cls.__init__)
def __init_binary__(self, left, right, *args, **kwargs):
@@ -26,6 +31,7 @@ def binary_operator(cls: Expr) -> Expr:
def unary_operator(cls: Expr) -> Expr:
+ """ Class decorator to modify constructor of unary operators """
__init_cls__ = cls.__init__
@wraps(cls.__init__)
def __init_unary__(self, left, *args, **kwargs):
@@ -41,7 +47,7 @@ T = TypeVar("T")
@runtime_checkable
class AlgebraicStructure(Protocol):
"""
- Provides method to enforce algebraic closure of operations
+ Provides methods to enforce algebraic closure of operations
"""
class Algebra(Enum):
""" Types of algebras """
@@ -92,7 +98,17 @@ class AlgebraicStructure(Protocol):
-class ReducibleExpr(Protocol):
+class ReducibleExpr(HasRepr, Protocol):
+ """
+ Algebraic expression that can be written in terms of other (existing)
+ expression objects, i.e. can be reduced to another expression made of
+ simpler operations. For example subtraction can be written in term of
+ addition and multiplication.
+ """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ return self.reduce().to_repr(repr_type, state)
+
@abstractmethod
def reduce(self) -> Expr: ...
@@ -113,10 +129,18 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
other = self._wrap_if_constant(other)
return Add(self, other)
+ def __radd__(self, other):
+ other = self._wrap_if_constant(other)
+ return Add(other, self)
+
def __sub__(self, other):
other = self._wrap_if_constant(other)
return Subtract(self, other)
+ def __rsub__(self, other):
+ other = self._wrap_if_constant(other)
+ return Subtract(other, self)
+
def __neg__(self, other):
return Multiply(self._constant(-1), other)
@@ -132,7 +156,8 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
other = self._wrap_if_constant(other)
if not self._is_const_or_param(other):
raise AlgebraicError("Cannot divide by variables in polynomial ring.")
- return Multiply(Const(value=1/other.value), self)
+
+ return Divide(self, other)
def __pow__(self, other):
other = self._wrap_if_constant(other)
@@ -143,25 +168,88 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
@binary_operator
-class Add(Expr, PolyRingAlgebra):
- """ Sum operator """
+class Add(Expr, HasRepr, PolyRingAlgebra):
+ """ Sum operator between scalar polynomials """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ # Make a new empty representation
+ r = repr_type()
+ entry = MatrixIndex.scalar
+
+ # Make representations for existing stuff
+ for node in self.nodes():
+ nrepr, state = node.to_repr(repr_type, state)
+
+ # Process non-zero terms
+ for term in nrepr.terms(entry):
+ s = r.at(entry, term) + nrepr.at(entry, term)
+ r.set(entry, term, s)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} + {self.right})"
@binary_operator
class Subtract(Expr, ReducibleExpr, PolyRingAlgebra):
- """ Subtraction operation """
+ """ Subtraction operation between scalar polynomials """
+
def reduce(self) -> Expr:
return Add(self.left, Multiply(self._constant(value=-1), self.right))
+ def __repr__(self) -> str:
+ return f"({self.left} - {self.right})"
+
@binary_operator
-class Multiply(Expr, PolyRingAlgebra):
- """ Multiplication operator """
+class Multiply(Expr, HasRepr, PolyRingAlgebra):
+ """ Multiplication operator between scalar polynomials """
+
+ def to_repr(self, repr_type: type, state: State) -> Repr:
+ r = repr_type()
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ entry = MatrixIndex.scalar
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
+
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} * {self.right})"
@binary_operator
-class Exponentiate(Expr, PolyRingAlgebra):
- """ Exponentiation operator """
+class Divide(Expr, ReducibleExpr, PolyRingAlgebra):
+ """ Division operator between scalar polynomials """
+
+ def reduce(self) -> Expr:
+ inverse = self._constant(value=1/self.right.value)
+ return Multiply(inverse, self.left)
+
+ def __repr__(self) -> str:
+ return f"({self.left} / {self.right})"
+
+
+@binary_operator
+class Exponentiate(Expr, ReducibleExpr, PolyRingAlgebra):
+ """ Exponentiation operator of scalar polynomials """
+
+ def reduce(self) -> Expr:
+ var = self.left
+ ntimes = self.right.value - 1
+ return reduce(operator.mul, (var for _ in range(ntimes)), var)
+
+ def __repr__(self) -> str:
+ return f"({self.left} ** {self.right})"
@unary_operator
@@ -213,12 +301,13 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
def transpose(self):
raise NotImplementedError
- # Shorthands & syntactic sugar
-
@property
def T(self):
return self.transpose()
+ def as_scalar(self) -> PolyRingAlgebra:
+ raise NotImplementedError
+
@binary_operator
class MatAdd(Expr, MatrixAlgebra):
diff --git a/mdpoly/representations.py b/mdpoly/representations.py
index 55763a9..5e58cb4 100644
--- a/mdpoly/representations.py
+++ b/mdpoly/representations.py
@@ -1,5 +1,6 @@
from .abc import Repr
from .types import Number, Shape, MatrixIndex, PolyIndex
+from .state import State
from typing import Protocol, Sequence
from abc import abstractmethod
@@ -8,6 +9,11 @@ import numpy as np
import numpy.typing as npt
+class HasRepr(Protocol):
+ @abstractmethod
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: ...
+
+
class SparseRepr(Repr):
"""
Sparse representation of polynomial
@@ -66,16 +72,20 @@ class SparseMatrixRepr(Repr):
def at(self, entry: MatrixIndex, term: PolyIndex) -> Number:
""" Access polynomial entry """
+ raise NotImplementedError
def set(self, entry: MatrixIndex, term: PolyIndex, value: Number) -> None:
""" Set value of polynomial entry """
+ raise NotImplementedError
def entries(self) -> Sequence[MatrixIndex]:
""" Return indices to non-zero entries of the matrix """
+ raise NotImplementedError
def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]:
""" Return indices to non-zero terms in the polynomial at the given
matrix entry """
+ raise NotImplementedError
def basis(self):
raise NotImplementedError