aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/abc.py12
-rw-r--r--mdpoly/algebra.py54
2 files changed, 30 insertions, 36 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 245f3a2..989fa7f 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -18,11 +18,7 @@ from dataclassabc import dataclassabc
from hashlib import sha256
-# Forward declaration of Repr class for TypeVar
-# Why is this shit necessary? Feels like writing C++
-class Repr: ...
-
-ReprT = TypeVar("ReprT", bound=Repr)
+ReprT = TypeVar("ReprT", bound="Repr")
# ┏━╸╻ ╻┏━┓┏━┓┏━╸┏━┓┏━┓╻┏━┓┏┓╻┏━┓
# ┣╸ ┏╋┛┣━┛┣┳┛┣╸ ┗━┓┗━┓┃┃ ┃┃┗┫┗━┓
@@ -109,7 +105,7 @@ class Expr(ABC):
@abstractmethod
def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- """ Convert to a concrete representation
+ """ Convert to a concrete representation.
To convert an abstract object from the expression tree, we ... TODO: finish.
@@ -125,13 +121,13 @@ class Expr(ABC):
return h.digest()
def children(self) -> Sequence[Expr]:
- """ Iterate over the two nodes """
+ """ Iterate over the two nodes. """
if self.is_leaf:
return tuple()
return self.left, self.right
def leaves(self) -> Iterable[Expr]:
- """ Returns the leaves of the tree. This is done recursively and in
+ r""" Returns the leaves of the tree. This is done recursively and in
:math:`\mathcal{O}(n)`."""
if self.left.is_leaf:
yield self.left
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 38373c0..04a61ad 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -1,7 +1,7 @@
""" Algebraic Structures for Expressions """
from __future__ import annotations
-from .abc import Algebra, Expr, Repr, Nothing, Const, Var, Param
+from .abc import Algebra, Expr, ReprT, Nothing, Const, Var, Param
from .errors import AlgebraicError, InvalidShape, MissingParameters
from .state import State
from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
@@ -15,11 +15,8 @@ from dataclassabc import dataclassabc
import operator
-ReprT = TypeVar("ReprT", bound=Repr)
-
-
class BinaryOp(Expr):
- """ Binary Operator """
+ """ Binary Operator. TODO: desc """
def __init__(self, left, right):
self._left = left
self._right = right
@@ -42,7 +39,7 @@ class BinaryOp(Expr):
class UnaryOp(Expr):
- """ Unary Operator """
+ """ Unary Operator. TODO: desc. """
def __init__(self, left, right=None):
self._inner = left
@@ -70,7 +67,7 @@ class UnaryOp(Expr):
class Reducible(Expr):
- """ Reducible Expression
+ """ Reducible Expression.
Algebraic expression that can be written in terms of other (existing)
expression objects, i.e. can be reduced to another expression made of
@@ -102,7 +99,7 @@ class Reducible(Expr):
class PolyRingExpr(Expr):
- r""" Endows with the algebraic structure of a polynomial ring.
+ r""" Expression with the algebraic behaviour of a polynomial ring.
This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the
polynomials are scalars.
@@ -138,7 +135,11 @@ class PolyRingExpr(Expr):
@classmethod
def make_combinations(cls, variables: Iterable[PolyVar], max_degree: int) -> Iterable[PolyRingExpr]:
- """ Make a list of """
+ """ Make combinations of terms.
+
+ For example given :math:`x, y` and *max_degree* of 2 generates
+ :math:`x^2, xy, x, y^2, y, 1`.
+ """
# Ignore typecheck because dataclassabc has no typing stub
vars_and_const = chain(variables, (PolyConst(1),)) # type: ignore[call-arg]
for comb in combinations_with_replacement(vars_and_const, max_degree):
@@ -240,7 +241,7 @@ class PolyParam(Param, PolyRingExpr):
raise MissingParameters("Cannot construct representation because "
f"value for parameter {self} was not given.")
- return PolyConst(state.parameters[self]).to_repr(repr_type, state)
+ return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
class PolyAdd(BinaryOp, PolyRingExpr):
@@ -308,7 +309,7 @@ class PolyMul(BinaryOp, PolyRingExpr):
class PolyExp(BinaryOp, PolyRingExpr, Reducible):
- """ Generic exponentiation (no type check). """
+ """ Exponentiation operator between scalar polynomials. """
@property # type: ignore[override]
def right(self) -> Const: # type: ignore[override]
@@ -338,7 +339,7 @@ class PolyExp(BinaryOp, PolyRingExpr, Reducible):
class PolyPartialDiff(UnaryOp, PolyRingExpr):
- """ Partial differentiation of scalar polynomials """
+ """ Partial differentiation of scalar polynomials. """
def __init__(self, inner: Expr, with_respect_to: Var):
UnaryOp.__init__(self, inner)
@@ -382,8 +383,8 @@ class RationalFieldExpr(Expr):
# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
class MatrixExpr(Expr):
- r""" Endows with the algebraic structure of a matrix ring and / or module
- depending on the shape.
+ r""" Expression with the algebraic properties of a matrix ring and / or
+ module (depending on the shape).
We denote with :math:`R` a polynomial ring.
@@ -447,6 +448,7 @@ class MatrixExpr(Expr):
def transpose(self) -> MatrixExpr:
""" Matrix transposition. """
+ raise NotImplementedError
return MatTranspose(self)
@property
@@ -461,9 +463,7 @@ class MatrixExpr(Expr):
@dataclassabc(frozen=True)
class MatConst(Const, MatrixExpr):
- """
- A matrix constant
- """
+ """ Matrix constant. TODO: desc. """
value: Sequence[Sequence[Number]] # Row major, overloads Const.value
shape: Shape # overloads Expr.shape
name: str = "" # overloads Leaf.name
@@ -471,9 +471,9 @@ class MatConst(Const, MatrixExpr):
def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
r = repr_type(self.shape)
- for r, row in enumerate(self.value):
- for c, val in enumerate(row):
- r.set(MatrixIndex(row=r, col=c), PolyIndex.constant(), val)
+ for i, row in enumerate(self.value):
+ for j, val in enumerate(row):
+ r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val)
return r, state
@@ -490,16 +490,15 @@ T = TypeVar("T", bound=Var)
@dataclassabc(frozen=True)
class MatVar(Var, MatrixExpr):
- """
- Matrix polynomial variable
- """
+ """ Matrix of polynomial variables. TODO: desc """
name: str # overloads Leaf.name
shape: Shape # overloads Expr.shape
+ # TODO: review this API, can be moved elsewhere?
def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
for row in range(self.shape.rows):
for col in range(self.shape.cols):
- var = scalar_var_type(name=f"{self.name}_[{row},{col}]")
+ var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg]
entry = MatrixIndex(row, col)
yield entry, var
@@ -520,9 +519,7 @@ class MatVar(Var, MatrixExpr):
@dataclassabc(frozen=True)
class MatParam(Param, MatrixExpr):
- """
- Matrix parameter
- """
+ """ Matrix parameter. TODO: desc. """
name: str
shape: Shape
@@ -532,7 +529,8 @@ class MatParam(Param, MatrixExpr):
f"value for parameter {self} was not given.")
# FIXME: add conversion to scalar variables
- return MatConst(state.parameters[self]).to_repr(repr_type, state)
+ # Ignore typecheck because dataclassabc has not type stub
+ return MatConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
def __repr__(self) -> str:
return self.name