diff options
Diffstat (limited to '')
-rw-r--r-- | mdpoly/abc.py | 12 | ||||
-rw-r--r-- | mdpoly/algebra.py | 54 |
2 files changed, 30 insertions, 36 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 245f3a2..989fa7f 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -18,11 +18,7 @@ from dataclassabc import dataclassabc from hashlib import sha256 -# Forward declaration of Repr class for TypeVar -# Why is this shit necessary? Feels like writing C++ -class Repr: ... - -ReprT = TypeVar("ReprT", bound=Repr) +ReprT = TypeVar("ReprT", bound="Repr") # ┏━╸╻ ╻┏━┓┏━┓┏━╸┏━┓┏━┓╻┏━┓┏┓╻┏━┓ # ┣╸ ┏╋┛┣━┛┣┳┛┣╸ ┗━┓┗━┓┃┃ ┃┃┗┫┗━┓ @@ -109,7 +105,7 @@ class Expr(ABC): @abstractmethod def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - """ Convert to a concrete representation + """ Convert to a concrete representation. To convert an abstract object from the expression tree, we ... TODO: finish. @@ -125,13 +121,13 @@ class Expr(ABC): return h.digest() def children(self) -> Sequence[Expr]: - """ Iterate over the two nodes """ + """ Iterate over the two nodes. """ if self.is_leaf: return tuple() return self.left, self.right def leaves(self) -> Iterable[Expr]: - """ Returns the leaves of the tree. This is done recursively and in + r""" Returns the leaves of the tree. This is done recursively and in :math:`\mathcal{O}(n)`.""" if self.left.is_leaf: yield self.left diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 38373c0..04a61ad 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -1,7 +1,7 @@ """ Algebraic Structures for Expressions """ from __future__ import annotations -from .abc import Algebra, Expr, Repr, Nothing, Const, Var, Param +from .abc import Algebra, Expr, ReprT, Nothing, Const, Var, Param from .errors import AlgebraicError, InvalidShape, MissingParameters from .state import State from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number @@ -15,11 +15,8 @@ from dataclassabc import dataclassabc import operator -ReprT = TypeVar("ReprT", bound=Repr) - - class BinaryOp(Expr): - """ Binary Operator """ + """ Binary Operator. TODO: desc """ def __init__(self, left, right): self._left = left self._right = right @@ -42,7 +39,7 @@ class BinaryOp(Expr): class UnaryOp(Expr): - """ Unary Operator """ + """ Unary Operator. TODO: desc. """ def __init__(self, left, right=None): self._inner = left @@ -70,7 +67,7 @@ class UnaryOp(Expr): class Reducible(Expr): - """ Reducible Expression + """ Reducible Expression. Algebraic expression that can be written in terms of other (existing) expression objects, i.e. can be reduced to another expression made of @@ -102,7 +99,7 @@ class Reducible(Expr): class PolyRingExpr(Expr): - r""" Endows with the algebraic structure of a polynomial ring. + r""" Expression with the algebraic behaviour of a polynomial ring. This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the polynomials are scalars. @@ -138,7 +135,11 @@ class PolyRingExpr(Expr): @classmethod def make_combinations(cls, variables: Iterable[PolyVar], max_degree: int) -> Iterable[PolyRingExpr]: - """ Make a list of """ + """ Make combinations of terms. + + For example given :math:`x, y` and *max_degree* of 2 generates + :math:`x^2, xy, x, y^2, y, 1`. + """ # Ignore typecheck because dataclassabc has no typing stub vars_and_const = chain(variables, (PolyConst(1),)) # type: ignore[call-arg] for comb in combinations_with_replacement(vars_and_const, max_degree): @@ -240,7 +241,7 @@ class PolyParam(Param, PolyRingExpr): raise MissingParameters("Cannot construct representation because " f"value for parameter {self} was not given.") - return PolyConst(state.parameters[self]).to_repr(repr_type, state) + return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] class PolyAdd(BinaryOp, PolyRingExpr): @@ -308,7 +309,7 @@ class PolyMul(BinaryOp, PolyRingExpr): class PolyExp(BinaryOp, PolyRingExpr, Reducible): - """ Generic exponentiation (no type check). """ + """ Exponentiation operator between scalar polynomials. """ @property # type: ignore[override] def right(self) -> Const: # type: ignore[override] @@ -338,7 +339,7 @@ class PolyExp(BinaryOp, PolyRingExpr, Reducible): class PolyPartialDiff(UnaryOp, PolyRingExpr): - """ Partial differentiation of scalar polynomials """ + """ Partial differentiation of scalar polynomials. """ def __init__(self, inner: Expr, with_respect_to: Var): UnaryOp.__init__(self, inner) @@ -382,8 +383,8 @@ class RationalFieldExpr(Expr): # ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ class MatrixExpr(Expr): - r""" Endows with the algebraic structure of a matrix ring and / or module - depending on the shape. + r""" Expression with the algebraic properties of a matrix ring and / or + module (depending on the shape). We denote with :math:`R` a polynomial ring. @@ -447,6 +448,7 @@ class MatrixExpr(Expr): def transpose(self) -> MatrixExpr: """ Matrix transposition. """ + raise NotImplementedError return MatTranspose(self) @property @@ -461,9 +463,7 @@ class MatrixExpr(Expr): @dataclassabc(frozen=True) class MatConst(Const, MatrixExpr): - """ - A matrix constant - """ + """ Matrix constant. TODO: desc. """ value: Sequence[Sequence[Number]] # Row major, overloads Const.value shape: Shape # overloads Expr.shape name: str = "" # overloads Leaf.name @@ -471,9 +471,9 @@ class MatConst(Const, MatrixExpr): def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: r = repr_type(self.shape) - for r, row in enumerate(self.value): - for c, val in enumerate(row): - r.set(MatrixIndex(row=r, col=c), PolyIndex.constant(), val) + for i, row in enumerate(self.value): + for j, val in enumerate(row): + r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val) return r, state @@ -490,16 +490,15 @@ T = TypeVar("T", bound=Var) @dataclassabc(frozen=True) class MatVar(Var, MatrixExpr): - """ - Matrix polynomial variable - """ + """ Matrix of polynomial variables. TODO: desc """ name: str # overloads Leaf.name shape: Shape # overloads Expr.shape + # TODO: review this API, can be moved elsewhere? def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]: for row in range(self.shape.rows): for col in range(self.shape.cols): - var = scalar_var_type(name=f"{self.name}_[{row},{col}]") + var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg] entry = MatrixIndex(row, col) yield entry, var @@ -520,9 +519,7 @@ class MatVar(Var, MatrixExpr): @dataclassabc(frozen=True) class MatParam(Param, MatrixExpr): - """ - Matrix parameter - """ + """ Matrix parameter. TODO: desc. """ name: str shape: Shape @@ -532,7 +529,8 @@ class MatParam(Param, MatrixExpr): f"value for parameter {self} was not given.") # FIXME: add conversion to scalar variables - return MatConst(state.parameters[self]).to_repr(repr_type, state) + # Ignore typecheck because dataclassabc has not type stub + return MatConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] def __repr__(self) -> str: return self.name |