diff options
author | Nao Pross <np@0hm.ch> | 2024-03-18 19:02:44 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-18 19:02:44 +0100 |
commit | 8676dc34530be9669cd3dfc5bd5a54cb8f9ac482 (patch) | |
tree | 9036846faeec4e12e7666daa4e151c205dd8eb15 | |
parent | Rename PolyRingExpr to just PolyExpr (diff) | |
download | mdpoly-8676dc34530be9669cd3dfc5bd5a54cb8f9ac482.tar.gz mdpoly-8676dc34530be9669cd3dfc5bd5a54cb8f9ac482.zip |
Move operations (operator overloading etc) outside of expression class
The new class WithOps (TODO: better name) wraps around an expression
tree and provides operator overloading and convenient methods. This
requires a big structural change hence most things are now broken.
-rw-r--r-- | mdpoly/__init__.py | 23 | ||||
-rw-r--r-- | mdpoly/abc.py | 112 | ||||
-rw-r--r-- | mdpoly/expressions.py | 235 | ||||
-rw-r--r-- | mdpoly/expressions/__init__.py | 66 | ||||
-rw-r--r-- | mdpoly/expressions/matrix.py | 172 | ||||
-rw-r--r-- | mdpoly/expressions/poly.py | 170 | ||||
-rw-r--r-- | mdpoly/operations/__init__.py | 73 | ||||
-rw-r--r-- | mdpoly/operations/add.py | 18 | ||||
-rw-r--r-- | mdpoly/operations/derivative.py | 16 | ||||
-rw-r--r-- | mdpoly/operations/exp.py | 8 | ||||
-rw-r--r-- | mdpoly/operations/mul.py | 14 | ||||
-rw-r--r-- | mdpoly/operations/transpose.py | 3 |
12 files changed, 356 insertions, 554 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index 8334d87..26025ab 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -107,8 +107,9 @@ TODO: data structure(s) that represent the polynomial from .index import (Shape as _Shape) -from .expressions.poly import (PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam) -from .expressions.matrix import (MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam) +from .expressions import (WithOps as _WithOps, + PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam, + MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam) from .state import State as _State @@ -138,27 +139,33 @@ class FromHelpers: yield from map(cls, names) -class Constant(_PolyConst, FromHelpers): +class Constant(_WithOps, FromHelpers): """ Constant values """ + def __init__(self, *args, **kwargs): + _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs)) -class Variable(_PolyVar, FromHelpers): +class Variable(_WithOps, FromHelpers): """ Polynomial Variable """ + def __init__(self, *args, **kwargs): + _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs)) -class Parameter(_PolyParam, FromHelpers): +class Parameter(_WithOps, FromHelpers): """ Parameter that can be substituted """ + def __init__(self, *args, **kwargs): + _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs)) -class MatrixConstant(_MatConst): +class MatrixConstant(_WithOps): """ Matrix constant """ -class MatrixVariable(_MatVar): +class MatrixVariable(_WithOps): """ Matrix Polynomial Variable """ -class MatrixParameter(_MatParam): +class MatrixParameter(_WithOps): """ Matrix Parameter """ diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 882634f..1f15361 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from typing import Self, Type, TypeVar, Generic, Any, Iterable, Sequence +from typing import Self, Type, TypeVar, Generic, Iterable, Sequence from enum import Enum, auto from copy import copy from abc import ABC, abstractmethod @@ -10,7 +10,6 @@ from hashlib import sha256 from dataclassabc import dataclassabc from .constants import NUMERICS_EPS -from .errors import AlgebraicError from .index import Shape from .util import iszero @@ -24,13 +23,6 @@ if TYPE_CHECKING: # ┗━╸╹ ╹╹ ╹┗╸┗━╸┗━┛┗━┛╹┗━┛╹ ╹┗━┛ -class Algebra(Enum): - """ Types of algebras. """ - none = auto() - poly_ring = auto() - matrix_ring = auto() - - class Expr(ABC): """ Merkle binary tree to represent a mathematical expression. """ @@ -65,19 +57,6 @@ class Expr(ABC): @property @abstractmethod - def algebra(self) -> Algebra: - """ Specifies to which algebra belongs the expression. - - This is used to provide nicer error messages and to avoid accidental - mistakes (like adding a scalar and a matrix) which are hard to debug - as expressions are lazily evalated (i.e. without this the exception for - adding a scalar to a matrix does not occurr at the line where the two - are added but rather where :py:meth:`mdpoly.abc.Expr.to_repr` is - called). - """ - - @property - @abstractmethod def shape(self) -> Shape: # TODO: Test on very large expressions. If there is a performance hit # change to functools.cached_property. @@ -198,78 +177,6 @@ class Expr(ABC): return pivot - # --- Private methods --- - - @staticmethod - def _assert_same_algebra(left: Expr, right: Expr) -> None: - if not isinstance(left, Expr) or not isinstance(right, Expr): - return - - if left.algebra != right.algebra: - raise AlgebraicError("Cannot perform algebraic operation between " - f"{left} and {right} because they have different " - f"algebras {left.algebra} and {right.algebra}.") - - @staticmethod - def _wrap(if_type: type, wrapper_type: type, obj: Any, *args, **kwargs) -> Expr: - """ Wrap non-expr objects. - - Suppose ``x`` is of type *Expr*, then we would like to be able to do - things like ``x + 1``, so this function can be called in operator - overloadings to wrap the 1 into a :py:class:`mdpoly.leaves.Const`. If - ``obj`` is already of type *Expr*, this function does nothing. The - arguments *args* and *kwargs* are forwarded to the constructor of - *wrapper_type*. - """ - # Do not wrap if is alreay an expression - if isinstance(obj, Expr): - return obj - - if not isinstance(obj, if_type): - raise TypeError(f"Cannot wrap {obj} with type {wrapper_type} because " - f"it is not of type {if_type}.") - - return wrapper_type(obj, *args, **kwargs) - - # --- Operator overloading --- - - def __add__(self, other: Any) -> Self: - raise NotImplementedError - - def __radd__(self, other: Any) -> Self: - raise NotImplementedError - - def __sub__(self, other: Any) -> Self: - raise NotImplementedError - - def __rsub__(self, other: Any) -> Self: - raise NotImplementedError - - def __neg__(self) -> Self: - raise NotImplementedError - - def __mul__(self, other: Any) -> Self: - raise NotImplementedError - - def __rmul__(self, other: Any) -> Self: - raise NotImplementedError - - def __pow__(self, other: Any) -> Self: - raise NotImplementedError - - def __matmul__(self, other: Any) -> Self: - raise NotImplementedError - - def __rmatmul__(self, other: Any) -> Self: - raise NotImplementedError - - def __truediv__(self, other: Any) -> Self: - raise NotImplementedError - - def __rtruediv__(self, other: Any) -> Self: - raise NotImplementedError - - class Leaf(Expr): """ Leaf of the binary tree. """ @@ -316,7 +223,6 @@ class Leaf(Expr): def subtree_hash(self) -> bytes: h = sha256() h.update(self.__class__.__qualname__.encode("utf-8")) - h.update(bytes(self.algebra.value)) h.update(self.name.encode("utf-8")) # type: ignore h.update(bytes(self.shape)) return h.digest() @@ -326,7 +232,10 @@ T = TypeVar("T") class Const(Leaf, Generic[T]): - """ Leaf to represent a constant. TODO: desc. """ + """ Leaf to represent a constant. + + TODO: desc. + """ @property @abstractmethod @@ -341,11 +250,17 @@ class Const(Leaf, Generic[T]): class Var(Leaf): - """ Leaf to reprsent a Variable. TODO: desc """ + """ Leaf to reprsent a Variable. + + TODO: desc + """ class Param(Leaf): - """ Parameter. TODO: desc """ + """ Parameter. + + TODO: desc + """ @dataclassabc(frozen=True) @@ -355,7 +270,6 @@ class Nothing(Leaf): """ name: str = "." shape: Shape = Shape(0, 0) - algebra: Algebra = Algebra.none def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: raise ValueError("Nothing cannot be represented.") diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py new file mode 100644 index 0000000..c34a0aa --- /dev/null +++ b/mdpoly/expressions.py @@ -0,0 +1,235 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from dataclassabc import dataclassabc +from dataclasses import dataclass +from functools import wraps +from typing import Type, TypeVar, Iterable, Callable, Sequence, Any, Self, cast + +from .abc import Expr, Var, Const, Param +from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex +from .errors import MissingParameters + +if TYPE_CHECKING: + from .abc import ReprT + from .index import Shape, Number + from .state import State + + +# ┏┳┓┏━┓╺┳╸┏━┓╻┏━╸┏━╸┏━┓ +# ┃┃┃┣━┫ ┃ ┣┳┛┃┃ ┣╸ ┗━┓ +# ╹ ╹╹ ╹ ╹ ╹┗╸╹┗━╸┗━╸┗━┛ + + +@dataclassabc(frozen=True) +class MatConst(Const): + """ Matrix constant. TODO: desc. """ + value: Sequence[Sequence[Number]] # Row major, overloads Const.value + shape: Shape # overloads Expr.shape + name: str = "" # overloads Leaf.name + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + + for i, row in enumerate(self.value): + for j, val in enumerate(row): + r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val) + + return r, state + + + def __str__(self) -> str: + if not self.name: + return repr(self.value) + + return self.name + + +T = TypeVar("T", bound=Var) + +@dataclassabc(frozen=True) +class MatVar(Var): + """ Matrix of polynomial variables. TODO: desc """ + name: str # overloads Leaf.name + shape: Shape # overloads Expr.shape + + # TODO: review this API, can be moved elsewhere? + def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]: + for row in range(self.shape.rows): + for col in range(self.shape.cols): + var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg] + entry = MatrixIndex(row, col) + + yield entry, var + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + + # FIXME: do not hardcode scalar type + for entry, var in self.to_scalars(PolyVar): + idx = PolyVarIndex.from_var(var, state), # important comma! + r.set(entry, PolyIndex(idx), 1) + + return r, state + + def __str__(self) -> str: + return self.name + + +@dataclassabc(frozen=True) +class MatParam(Param): + """ Matrix parameter. TODO: desc. """ + name: str + shape: Shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + if self not in state.parameters: + raise MissingParameters("Cannot construct representation because " + f"value for parameter {self} was not given.") + + # FIXME: add conversion to scalar variables + # Ignore typecheck because dataclassabc has not type stub + return MatConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] + + def __str__(self) -> str: + return self.name + + +# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓ +# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┗━┓ +# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸┗━┛ + + +@dataclassabc(frozen=True) +class PolyVar(Var): + """ Variable TODO: desc """ + name: str # overloads Leaf.name + shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + idx = PolyVarIndex.from_var(self, state), # important comma! + r.set(MatrixIndex.scalar(), PolyIndex(idx), 1) + return r, state + + +@dataclassabc(frozen=True) +class PolyConst(Const): + """ Constant TODO: desc """ + value: Number # overloads Const.value + name: str = "" # overloads Leaf.name + shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value) + return r, state + + +@dataclassabc(frozen=True) +class PolyParam(Param): + """ Polynomial parameter TODO: desc """ + name: str # overloads Leaf.name + shape: Shape = Shape.scalar() # overloads PolyExpr.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + if self not in state.parameters: + raise MissingParameters("Cannot construct representation because " + f"value for parameter {self} was not given.") + + return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] + + +# ┏━┓┏━┓┏━╸┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓ +# ┃ ┃┣━┛┣╸ ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┗━┓ +# ┗━┛╹ ┗━╸╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹┗━┛ + + +@dataclass +class WithOps: + """ Monadic wrapper around :py:class:`mdpoly.abc.Expr` that adds operator + overloading and operations to expression objects. """ + expr: Expr + + # -- Monadic operations -- + + @staticmethod + def map(fn: Callable[[Expr], Expr]) -> Callable[[WithOps], WithOps]: + """ Map in the functional programming sense. + + Converts a function like `(Expr -> Expr)` into + `(WithOps<Expr> -> WithOps<Expr>)`. + """ + @wraps(fn) + def wrapper(e: WithOps) -> WithOps: + return WithOps(expr=fn(e.expr)) + return wrapper + + @staticmethod + def zip(fn: Callable[..., Expr]) -> Callable[..., WithOps]: + """ Zip, in the functional programming sense. + + Converts a function like `(Expr -> Expr -> ... -> Expr)` into + `(WithOps<Expr> -> WithOps<Expr> -> ... -> WithOps<Expr>)`. + """ + @wraps(fn) + def wrapper(*args: WithOps) -> WithOps: + return WithOps(expr=fn(*(arg.expr for arg in args))) + return wrapper + + @staticmethod + def bind(fn: Callable[[Expr], WithOps]) -> Callable[[WithOps], WithOps]: + """ Bind in the functional programming sense. Also sometimes known as + flatmap. + + Converts a funciton like `(Expr -> WithOps)` into `(WithOps -> WithOps)` + so that it can be composed with other functions. + """ + @wraps(fn) + def wrapper(arg: WithOps) -> WithOps: + return fn(arg.expr) + return wrapper + + # -- Operator overloading --- + + def __add__(self, other: WithOps | Number) -> WithOps: + if isinstance(other, Number): + other = WithOps(expr=PolyConst(value=other)) + + return add(self, other) + + def __radd__(self, other: WithOps | Number) -> WithOps: + if isinstance(other, Number): + other = WithOps(expr=PolyConst(value=other)) + + return add(other, self) + + def __sub__(self, other: Self) -> Self: + raise NotImplementedError + + def __rsub__(self, other: Self) -> Self: + raise NotImplementedError + + def __neg__(self) -> Self: + raise NotImplementedError + + def __mul__(self, other: Any) -> Self: + raise NotImplementedError + + def __rmul__(self, other: Any) -> Self: + raise NotImplementedError + + def __pow__(self, other: Any) -> Self: + raise NotImplementedError + + def __matmul__(self, other: Any) -> Self: + raise NotImplementedError + + def __rmatmul__(self, other: Any) -> Self: + raise NotImplementedError + + def __truediv__(self, other: Any) -> Self: + raise NotImplementedError + + def __rtruediv__(self, other: Any) -> Self: + raise NotImplementedError diff --git a/mdpoly/expressions/__init__.py b/mdpoly/expressions/__init__.py deleted file mode 100644 index 2efe342..0000000 --- a/mdpoly/expressions/__init__.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from abc import abstractmethod -from typing import Type -from dataclasses import field -from dataclassabc import dataclassabc - -from ..abc import Expr, Nothing - -if TYPE_CHECKING: - from ..abc import ReprT - from ..state import State - - -@dataclassabc(eq=False) -class BinaryOp(Expr): - left: Expr = field() - right: Expr = field() - - -@dataclassabc(eq=False) -class UnaryOp(Expr): - right: Expr = field() - - @property - def inner(self) -> Expr: - """ Inner expression on which the operator is acting, alias for right. """ - return self.right - - @property - def left(self) -> Expr: - return Nothing() - - @left.setter - def left(self, left) -> None: - if not isinstance(left, Nothing): - raise ValueError("Cannot set left of left-acting unary operator " - "to something that is not of type Nothing.") - - -class Reducible(Expr): - """ Reducible Expression. - - Algebraic expression that can be written in terms of other (existing) - expression objects, i.e. can be reduced to another expression made of - simpler operations. For example subtraction can be written in term of - addition and multiplication. - """ - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ - return self.reduce().to_repr(repr_type, state) - - @abstractmethod - def reduce(self) -> Expr: - """ Reduce the expression to its basic elements """ - - -# TODO: review this idea before implementing -# class Simplifiable(Expr): -# """ Simplifiable Expression. """ -# -# @abstractmethod -# def simplify(self) -> Expr: -# """ Simplify the expression """ diff --git a/mdpoly/expressions/matrix.py b/mdpoly/expressions/matrix.py deleted file mode 100644 index cd38383..0000000 --- a/mdpoly/expressions/matrix.py +++ /dev/null @@ -1,172 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from typing import Type, TypeVar, Iterable, Sequence -from dataclassabc import dataclassabc - -from .poly import PolyVar, PolyConst -from ..abc import Expr, Var, Const, Param, Algebra -from ..index import MatrixIndex, PolyIndex, PolyVarIndex -from ..errors import MissingParameters - -if TYPE_CHECKING: - from ..abc import ReprT - from ..index import Shape, Number - from ..state import State - - -class MatrixExpr(Expr): - r""" Expression with the algebraic properties of a matrix ring and / or - module (depending on the shape). - - We denote with :math:`R` a polynomial ring. - - - If the shape is square, like :math:`(n, n)` then this is :math:`M_n(R)` - the ring of matrices over :math:`R`. - - - If the shape is something else like a row or column (:math:`(m, p)`, - :math:`(1, n)` or :math:`(n, 1)`) this is a module, i.e. an algebra with - addition and scalar multiplication, where the "scalars" come from - :math:`R`. - - Furthermore some operators that are usually expected from matrices are - already included (eg. transposition). - """ - - @property - def algebra(self) -> Algebra: - return Algebra.matrix_ring - - def __add__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - return operations.add.MatAdd(self, other) - - def __sub__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - return operations.add.MatSub(self, other) - - def __rsub__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - return operations.add.MatSub(other, self) - - def __neg__(self): - # FIXME: Create PolyNeg? - return operations.mul.MatScalarMul(PolyConst(-1), self) - - def __mul__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - - # TODO: case distiction based on shapes - return operations.mul.MatScalarMul(other, self) - - - def __rmul__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - return operations.mul.MatScalarMul(other, self) - - def __matmul__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - return operations.MatMul(self, other) - - def __rmatmul(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, MatConst, other) - return operations.mul.MatMul(other, self) - - def __truediv__(self, scalar): - scalar = self._wrap_if_constant(scalar) - raise NotImplementedError - - def transpose(self) -> MatrixExpr: - """ Matrix transposition. """ - raise NotImplementedError - return operations.transpose.MatTranspose(self) - - @property - def T(self) -> MatrixExpr: - """ Shorthand for :py:meth:`mdpoly.expressions.matrix.MatrixExpr.transpose`. """ - return self.transpose() - - def to_scalar(self, scalar_type: type): - """ Convert to a scalar expression. """ - raise NotImplementedError - - -@dataclassabc(frozen=True) -class MatConst(Const, MatrixExpr): - """ Matrix constant. TODO: desc. """ - value: Sequence[Sequence[Number]] # Row major, overloads Const.value - shape: Shape # overloads Expr.shape - name: str = "" # overloads Leaf.name - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - - for i, row in enumerate(self.value): - for j, val in enumerate(row): - r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val) - - return r, state - - - def __str__(self) -> str: - if not self.name: - return repr(self.value) - - return self.name - - -T = TypeVar("T", bound=Var) - -@dataclassabc(frozen=True) -class MatVar(Var, MatrixExpr): - """ Matrix of polynomial variables. TODO: desc """ - name: str # overloads Leaf.name - shape: Shape # overloads Expr.shape - - # TODO: review this API, can be moved elsewhere? - def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]: - for row in range(self.shape.rows): - for col in range(self.shape.cols): - var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg] - entry = MatrixIndex(row, col) - - yield entry, var - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - - # FIXME: do not hardcode scalar type - for entry, var in self.to_scalars(PolyVar): - idx = PolyVarIndex.from_var(var, state), # important comma! - r.set(entry, PolyIndex(idx), 1) - - return r, state - - def __str__(self) -> str: - return self.name - - -@dataclassabc(frozen=True) -class MatParam(Param, MatrixExpr): - """ Matrix parameter. TODO: desc. """ - name: str - shape: Shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - if self not in state.parameters: - raise MissingParameters("Cannot construct representation because " - f"value for parameter {self} was not given.") - - # FIXME: add conversion to scalar variables - # Ignore typecheck because dataclassabc has not type stub - return MatConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] - - def __str__(self) -> str: - return self.name diff --git a/mdpoly/expressions/poly.py b/mdpoly/expressions/poly.py deleted file mode 100644 index 623e89e..0000000 --- a/mdpoly/expressions/poly.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from typing import Type, Iterable, cast -from functools import reduce -from itertools import chain, combinations_with_replacement -from dataclassabc import dataclassabc -from operator import mul as opmul - -from ..abc import Expr, Var, Const, Param, Algebra, ReprT -from ..index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number -from ..errors import AlgebraicError, MissingParameters, InvalidShape - -from .. import operations - -if TYPE_CHECKING: - from ..index import Number - from ..state import State - - -class PolyExpr(Expr): - r""" Expression with the algebraic behaviour of a polynomial ring. - - This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the - polynomials are scalars. - """ - - # -- Properties --- - - @property - def algebra(self) -> Algebra: - return Algebra.poly_ring - - @property - def shape(self) -> Shape: - """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - if self.left.shape != self.right.shape: - raise InvalidShape(f"Cannot perform operation {repr(self)} with " - f"shapes {self.left.shape} and {self.right.shape}.") - return self.left.shape - - # --- Helper methods for construction --- - - @classmethod - def from_powers(cls, variable: PolyExpr, exponents: Iterable[int]) -> PolyExpr: - """ Given a variable, say :math:`x`, and a list of exponents ``[0, 3, 5]`` - returns a polynomial :math:`x^5 + x^3 + 1`. - """ - nonzero_exponents = filter(lambda e: e != 0, exponents) - # Ignore typecheck because dataclassabc has no typing stub - constant_term = PolyConst(1 if 0 in exponents else 0) # type: ignore[call-arg] - # FIXME: remove annoying 0 - return sum((variable ** e for e in nonzero_exponents), constant_term) - - - @classmethod - def make_combinations(cls, variables: Iterable[PolyVar], max_degree: int) -> Iterable[PolyExpr]: - """ Make combinations of terms. - - For example given :math:`x, y` and *max_degree* of 2 generates - :math:`x^2, xy, x, y^2, y, 1`. - """ - # Ignore typecheck because dataclassabc has no typing stub - vars_and_const = chain(variables, (PolyConst(1),)) # type: ignore[call-arg] - for comb in combinations_with_replacement(vars_and_const, max_degree): - yield cast(PolyExpr, reduce(opmul, comb)) - - # --- Operator Overloading --- - - def __add__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, PolyConst, other) - return operations.add.PolyAdd(self, other) - - def __radd__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, PolyConst, other) - return operations.add.PolyAdd(other, self) - - def __sub__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, PolyConst, other) - return operations.add.PolySub(self, other) - - def __rsub__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, PolyConst, other) - return operations.add.PolyAdd(other, self) - - def __neg__(self): - # FIXME: Create PolyNeg? - return operations.mul.PolyMul(PolyConst(-1), self) - - def __mul__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, PolyConst, other) - return operations.mul.PolyMul(self, other) - - def __rmul__(self, other): - self._assert_same_algebra(self, other) - other = self._wrap(Number, PolyConst, other) - return operations.mul.PolyMul(other, self) - - def __matmul__(self, other): - raise AlgebraicError("Cannot perform matrix multiplication in polynomial ring (they are scalars).") - - def __rmatmul__(self, other): - self.__rmatmul__(other) - - def __truediv__(self, other): - other = self._wrap(Number, PolyConst, other) - if not isinstance(other, PolyConst | PolyParam): - raise AlgebraicError("Cannot divide by variables in polynomial ring.") - - return operations.mul.PolyMul(self, other) - - def __rtruediv__(self, other): - raise AlgebraicError("Cannot perform right division in polynomial ring.") - - def __pow__(self, other): - other = self._wrap(Number, PolyConst, other) - if not isinstance(other, PolyConst | PolyParam): - raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in " - "polynomial ring. Only constants and parameters are allowed.") - return operations.exp.PolyExp(left=self, right=other) - - # -- Other mathematical operations --- - - def diff(self, wrt: PolyVar) -> operations.derivative.PolyPartialDiff: - return operations.derivative.PolyPartialDiff(right=self, wrt=wrt) - - -@dataclassabc(frozen=True) -class PolyVar(Var, PolyExpr): - """ Variable TODO: desc """ - name: str # overloads Leaf.name - shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - idx = PolyVarIndex.from_var(self, state), # important comma! - r.set(MatrixIndex.scalar(), PolyIndex(idx), 1) - return r, state - - -@dataclassabc(frozen=True) -class PolyConst(Const, PolyExpr): - """ Constant TODO: desc """ - value: Number # overloads Const.value - name: str = "" # overloads Leaf.name - shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value) - return r, state - - -@dataclassabc(frozen=True) -class PolyParam(Param, PolyExpr): - """ Polynomial parameter TODO: desc """ - name: str # overloads Leaf.name - shape: Shape = Shape.scalar() # overloads PolyExpr.shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - if self not in state.parameters: - raise MissingParameters("Cannot construct representation because " - f"value for parameter {self} was not given.") - - return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] diff --git a/mdpoly/operations/__init__.py b/mdpoly/operations/__init__.py index e69de29..752358c 100644 --- a/mdpoly/operations/__init__.py +++ b/mdpoly/operations/__init__.py @@ -0,0 +1,73 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from typing import Type +from abc import abstractmethod +from dataclasses import field +from dataclassabc import dataclassabc + +from ..abc import ReprT, Expr, Nothing + +if TYPE_CHECKING: + from ..state import State + + +@dataclassabc(eq=False) +class BinaryOp(Expr): + """ Binary operator. + + TODO: desc + """ + left: Expr = field() + right: Expr = field() + + +@dataclassabc(eq=False) +class UnaryOp(Expr): + """ Unary operator. + + TODO: desc + """ + right: Expr = field() + + @property + def inner(self) -> Expr: + """ Inner expression on which the operator is acting, alias for right. """ + return self.right + + @property + def left(self) -> Expr: + return Nothing() + + @left.setter + def left(self, left) -> None: + if not isinstance(left, Nothing): + raise ValueError("Cannot set left of left-acting unary operator " + "to something that is not of type Nothing.") + + +class Reducible(Expr): + """ Reducible Expression. + + Algebraic expression that can be written in terms of other (existing) + expression objects, i.e. can be reduced to another expression made of + simpler operations. For example subtraction can be written in term of + addition and multiplication. + """ + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ + return self.reduce().to_repr(repr_type, state) + + @abstractmethod + def reduce(self) -> Expr: + """ Reduce the expression to its basic elements """ + + +# TODO: review this idea before implementing +# class Simplifiable(Expr): +# """ Simplifiable Expression. """ +# +# @abstractmethod +# def simplify(self) -> Expr: +# """ Simplify the expression """ diff --git a/mdpoly/operations/add.py b/mdpoly/operations/add.py index b0174ef..30ad3b4 100644 --- a/mdpoly/operations/add.py +++ b/mdpoly/operations/add.py @@ -7,9 +7,7 @@ from dataclassabc import dataclassabc from ..abc import Expr from ..errors import AlgebraicError -from ..expressions import BinaryOp, Reducible -from ..expressions.matrix import MatrixExpr -from ..expressions.poly import PolyExpr +from . import BinaryOp, Reducible if TYPE_CHECKING: from ..abc import ReprT @@ -18,7 +16,7 @@ if TYPE_CHECKING: @dataclassabc(eq=False) -class MatAdd(BinaryOp, MatrixExpr): +class MatAdd(BinaryOp): """ Addition operator between matrices. """ @property @@ -52,7 +50,7 @@ class MatAdd(BinaryOp, MatrixExpr): @dataclassabc(eq=False) -class MatSub(BinaryOp, MatrixExpr, Reducible): +class MatSub(BinaryOp, Reducible): """ Subtraction operator between matrices. """ @property @@ -69,13 +67,3 @@ class MatSub(BinaryOp, MatrixExpr, Reducible): def __str__(self) -> str: return f"({self.left} - {self.right})" - - -@dataclassabc(eq=False) -class PolyAdd(MatAdd, PolyExpr): - ... - - -@dataclassabc(eq=False) -class PolySub(MatSub, PolyExpr): - ... diff --git a/mdpoly/operations/derivative.py b/mdpoly/operations/derivative.py index 9a74941..a771ca7 100644 --- a/mdpoly/operations/derivative.py +++ b/mdpoly/operations/derivative.py @@ -4,12 +4,11 @@ from typing import TYPE_CHECKING from typing import Type, Self, cast from dataclassabc import dataclassabc -from ..abc import Expr +from ..abc import Expr, Var from ..errors import AlgebraicError from ..index import Shape, MatrixIndex, PolyIndex -from ..expressions import UnaryOp -from ..expressions.poly import PolyExpr, PolyVar +from . import UnaryOp if TYPE_CHECKING: from ..abc import ReprT @@ -20,9 +19,9 @@ if TYPE_CHECKING: @dataclassabc(eq=False) -class PolyPartialDiff(UnaryOp, PolyExpr): +class PolyPartialDiff(UnaryOp): """ Partial differentiation of scalar polynomials. """ - wrt: PolyVar + wrt: Var @property def shape(self) -> Shape: @@ -50,13 +49,12 @@ class PolyPartialDiff(UnaryOp, PolyExpr): def __str__(self) -> str: return f"(∂_{self.wrt} {self.inner})" - def replace(self, old: PolyExpr, new: PolyExpr) -> Self: + def replace(self, old: Expr, new: Expr) -> Self: """ Overloads :py:meth:`mdpoly.abc.Expr.replace` """ if self.wrt == old: - if not isinstance(new, PolyVar): + if not isinstance(new, Var): # FIXME: implement chain rule raise AlgebraicError(f"Cannot take a derivative with respect to {new}.") - self.wrt = cast(PolyVar, new) - + self.wrt = cast(Var, new) return cast(Self, Expr.replace(self, old, new)) diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py index 177fd22..0f57fcb 100644 --- a/mdpoly/operations/exp.py +++ b/mdpoly/operations/exp.py @@ -9,20 +9,18 @@ from dataclasses import dataclass from ..abc import Expr, Const from ..errors import AlgebraicError -from ..expressions import BinaryOp, Reducible -from ..expressions.matrix import MatrixExpr -from ..expression.poly import PolyExpr +from . import BinaryOp, Reducible # TODO: implement matrix exponential, use caley-hamilton thm magic @dataclass(eq=False) -class MatExp(BinaryOp, MatrixExpr, Reducible): +class MatExp(BinaryOp, Reducible): def __init__(self): raise NotImplementedError @dataclass(eq=False) -class PolyExp(BinaryOp, PolyExpr, Reducible): +class PolyExp(BinaryOp, Reducible): """ Exponentiation operator between scalar polynomials. """ @property # type: ignore[override] diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py index 2a8f749..7e75d26 100644 --- a/mdpoly/operations/mul.py +++ b/mdpoly/operations/mul.py @@ -10,9 +10,7 @@ from ..index import Shape from ..errors import AlgebraicError, InvalidShape from ..index import MatrixIndex, PolyIndex -from ..expressions import BinaryOp, Reducible -from ..expressions.matrix import MatrixExpr -from ..expression.poly import PolyExpr +from . import BinaryOp, Reducible if TYPE_CHECKING: from ..abc import ReprT @@ -25,7 +23,7 @@ if TYPE_CHECKING: @dataclassabc -class MatElemMul(BinaryOp, MatrixExpr): +class MatElemMul(BinaryOp): """ Elementwise Matrix Multiplication. """ @property @@ -64,7 +62,7 @@ class MatElemMul(BinaryOp, MatrixExpr): @dataclassabc -class MatScalarMul(BinaryOp, MatrixExpr): +class MatScalarMul(BinaryOp): """ Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix on the right. """ @@ -101,7 +99,7 @@ class MatScalarMul(BinaryOp, MatrixExpr): @dataclass(eq=False) -class MatMul(BinaryOp, MatrixExpr): +class MatMul(BinaryOp): """ Matrix Multiplication. """ @property @@ -134,7 +132,7 @@ class MatMul(BinaryOp, MatrixExpr): @dataclass(eq=False) -class MatDotProd(BinaryOp, MatrixExpr, Reducible): +class MatDotProd(BinaryOp, Reducible): """ Dot product. """ @property @@ -157,7 +155,7 @@ class MatDotProd(BinaryOp, MatrixExpr, Reducible): @dataclass(eq=False) -class PolyMul(BinaryOp, PolyExpr): +class PolyMul(BinaryOp): """ Multiplication operator between scalar polynomials. """ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: diff --git a/mdpoly/operations/transpose.py b/mdpoly/operations/transpose.py index 750d0c0..0e4228b 100644 --- a/mdpoly/operations/transpose.py +++ b/mdpoly/operations/transpose.py @@ -4,14 +4,13 @@ from typing import TYPE_CHECKING from dataclasses import dataclass from ..expressions import UnaryOp -from ..expressions.matrix import MatrixExpr if TYPE_CHECKING: from ..index import Shape @dataclass(eq=False) -class MatTranspose(UnaryOp, MatrixExpr): +class MatTranspose(UnaryOp): """ Matrix transposition """ @property |