aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-18 19:02:44 +0100
committerNao Pross <np@0hm.ch>2024-03-18 19:02:44 +0100
commit8676dc34530be9669cd3dfc5bd5a54cb8f9ac482 (patch)
tree9036846faeec4e12e7666daa4e151c205dd8eb15
parentRename PolyRingExpr to just PolyExpr (diff)
downloadmdpoly-8676dc34530be9669cd3dfc5bd5a54cb8f9ac482.tar.gz
mdpoly-8676dc34530be9669cd3dfc5bd5a54cb8f9ac482.zip
Move operations (operator overloading etc) outside of expression class
The new class WithOps (TODO: better name) wraps around an expression tree and provides operator overloading and convenient methods. This requires a big structural change hence most things are now broken.
-rw-r--r--mdpoly/__init__.py23
-rw-r--r--mdpoly/abc.py112
-rw-r--r--mdpoly/expressions.py235
-rw-r--r--mdpoly/expressions/__init__.py66
-rw-r--r--mdpoly/expressions/matrix.py172
-rw-r--r--mdpoly/expressions/poly.py170
-rw-r--r--mdpoly/operations/__init__.py73
-rw-r--r--mdpoly/operations/add.py18
-rw-r--r--mdpoly/operations/derivative.py16
-rw-r--r--mdpoly/operations/exp.py8
-rw-r--r--mdpoly/operations/mul.py14
-rw-r--r--mdpoly/operations/transpose.py3
12 files changed, 356 insertions, 554 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index 8334d87..26025ab 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -107,8 +107,9 @@ TODO: data structure(s) that represent the polynomial
from .index import (Shape as _Shape)
-from .expressions.poly import (PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam)
-from .expressions.matrix import (MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam)
+from .expressions import (WithOps as _WithOps,
+ PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam,
+ MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam)
from .state import State as _State
@@ -138,27 +139,33 @@ class FromHelpers:
yield from map(cls, names)
-class Constant(_PolyConst, FromHelpers):
+class Constant(_WithOps, FromHelpers):
""" Constant values """
+ def __init__(self, *args, **kwargs):
+ _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs))
-class Variable(_PolyVar, FromHelpers):
+class Variable(_WithOps, FromHelpers):
""" Polynomial Variable """
+ def __init__(self, *args, **kwargs):
+ _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs))
-class Parameter(_PolyParam, FromHelpers):
+class Parameter(_WithOps, FromHelpers):
""" Parameter that can be substituted """
+ def __init__(self, *args, **kwargs):
+ _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs))
-class MatrixConstant(_MatConst):
+class MatrixConstant(_WithOps):
""" Matrix constant """
-class MatrixVariable(_MatVar):
+class MatrixVariable(_WithOps):
""" Matrix Polynomial Variable """
-class MatrixParameter(_MatParam):
+class MatrixParameter(_WithOps):
""" Matrix Parameter """
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 882634f..1f15361 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
-from typing import Self, Type, TypeVar, Generic, Any, Iterable, Sequence
+from typing import Self, Type, TypeVar, Generic, Iterable, Sequence
from enum import Enum, auto
from copy import copy
from abc import ABC, abstractmethod
@@ -10,7 +10,6 @@ from hashlib import sha256
from dataclassabc import dataclassabc
from .constants import NUMERICS_EPS
-from .errors import AlgebraicError
from .index import Shape
from .util import iszero
@@ -24,13 +23,6 @@ if TYPE_CHECKING:
# ┗━╸╹ ╹╹ ╹┗╸┗━╸┗━┛┗━┛╹┗━┛╹ ╹┗━┛
-class Algebra(Enum):
- """ Types of algebras. """
- none = auto()
- poly_ring = auto()
- matrix_ring = auto()
-
-
class Expr(ABC):
""" Merkle binary tree to represent a mathematical expression. """
@@ -65,19 +57,6 @@ class Expr(ABC):
@property
@abstractmethod
- def algebra(self) -> Algebra:
- """ Specifies to which algebra belongs the expression.
-
- This is used to provide nicer error messages and to avoid accidental
- mistakes (like adding a scalar and a matrix) which are hard to debug
- as expressions are lazily evalated (i.e. without this the exception for
- adding a scalar to a matrix does not occurr at the line where the two
- are added but rather where :py:meth:`mdpoly.abc.Expr.to_repr` is
- called).
- """
-
- @property
- @abstractmethod
def shape(self) -> Shape:
# TODO: Test on very large expressions. If there is a performance hit
# change to functools.cached_property.
@@ -198,78 +177,6 @@ class Expr(ABC):
return pivot
- # --- Private methods ---
-
- @staticmethod
- def _assert_same_algebra(left: Expr, right: Expr) -> None:
- if not isinstance(left, Expr) or not isinstance(right, Expr):
- return
-
- if left.algebra != right.algebra:
- raise AlgebraicError("Cannot perform algebraic operation between "
- f"{left} and {right} because they have different "
- f"algebras {left.algebra} and {right.algebra}.")
-
- @staticmethod
- def _wrap(if_type: type, wrapper_type: type, obj: Any, *args, **kwargs) -> Expr:
- """ Wrap non-expr objects.
-
- Suppose ``x`` is of type *Expr*, then we would like to be able to do
- things like ``x + 1``, so this function can be called in operator
- overloadings to wrap the 1 into a :py:class:`mdpoly.leaves.Const`. If
- ``obj`` is already of type *Expr*, this function does nothing. The
- arguments *args* and *kwargs* are forwarded to the constructor of
- *wrapper_type*.
- """
- # Do not wrap if is alreay an expression
- if isinstance(obj, Expr):
- return obj
-
- if not isinstance(obj, if_type):
- raise TypeError(f"Cannot wrap {obj} with type {wrapper_type} because "
- f"it is not of type {if_type}.")
-
- return wrapper_type(obj, *args, **kwargs)
-
- # --- Operator overloading ---
-
- def __add__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __radd__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __sub__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __rsub__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __neg__(self) -> Self:
- raise NotImplementedError
-
- def __mul__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __rmul__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __pow__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __matmul__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __rmatmul__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __truediv__(self, other: Any) -> Self:
- raise NotImplementedError
-
- def __rtruediv__(self, other: Any) -> Self:
- raise NotImplementedError
-
-
class Leaf(Expr):
""" Leaf of the binary tree. """
@@ -316,7 +223,6 @@ class Leaf(Expr):
def subtree_hash(self) -> bytes:
h = sha256()
h.update(self.__class__.__qualname__.encode("utf-8"))
- h.update(bytes(self.algebra.value))
h.update(self.name.encode("utf-8")) # type: ignore
h.update(bytes(self.shape))
return h.digest()
@@ -326,7 +232,10 @@ T = TypeVar("T")
class Const(Leaf, Generic[T]):
- """ Leaf to represent a constant. TODO: desc. """
+ """ Leaf to represent a constant.
+
+ TODO: desc.
+ """
@property
@abstractmethod
@@ -341,11 +250,17 @@ class Const(Leaf, Generic[T]):
class Var(Leaf):
- """ Leaf to reprsent a Variable. TODO: desc """
+ """ Leaf to reprsent a Variable.
+
+ TODO: desc
+ """
class Param(Leaf):
- """ Parameter. TODO: desc """
+ """ Parameter.
+
+ TODO: desc
+ """
@dataclassabc(frozen=True)
@@ -355,7 +270,6 @@ class Nothing(Leaf):
"""
name: str = "."
shape: Shape = Shape(0, 0)
- algebra: Algebra = Algebra.none
def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
raise ValueError("Nothing cannot be represented.")
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
new file mode 100644
index 0000000..c34a0aa
--- /dev/null
+++ b/mdpoly/expressions.py
@@ -0,0 +1,235 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from dataclassabc import dataclassabc
+from dataclasses import dataclass
+from functools import wraps
+from typing import Type, TypeVar, Iterable, Callable, Sequence, Any, Self, cast
+
+from .abc import Expr, Var, Const, Param
+from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex
+from .errors import MissingParameters
+
+if TYPE_CHECKING:
+ from .abc import ReprT
+ from .index import Shape, Number
+ from .state import State
+
+
+# ┏┳┓┏━┓╺┳╸┏━┓╻┏━╸┏━╸┏━┓
+# ┃┃┃┣━┫ ┃ ┣┳┛┃┃ ┣╸ ┗━┓
+# ╹ ╹╹ ╹ ╹ ╹┗╸╹┗━╸┗━╸┗━┛
+
+
+@dataclassabc(frozen=True)
+class MatConst(Const):
+ """ Matrix constant. TODO: desc. """
+ value: Sequence[Sequence[Number]] # Row major, overloads Const.value
+ shape: Shape # overloads Expr.shape
+ name: str = "" # overloads Leaf.name
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+
+ for i, row in enumerate(self.value):
+ for j, val in enumerate(row):
+ r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val)
+
+ return r, state
+
+
+ def __str__(self) -> str:
+ if not self.name:
+ return repr(self.value)
+
+ return self.name
+
+
+T = TypeVar("T", bound=Var)
+
+@dataclassabc(frozen=True)
+class MatVar(Var):
+ """ Matrix of polynomial variables. TODO: desc """
+ name: str # overloads Leaf.name
+ shape: Shape # overloads Expr.shape
+
+ # TODO: review this API, can be moved elsewhere?
+ def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
+ for row in range(self.shape.rows):
+ for col in range(self.shape.cols):
+ var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg]
+ entry = MatrixIndex(row, col)
+
+ yield entry, var
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+
+ # FIXME: do not hardcode scalar type
+ for entry, var in self.to_scalars(PolyVar):
+ idx = PolyVarIndex.from_var(var, state), # important comma!
+ r.set(entry, PolyIndex(idx), 1)
+
+ return r, state
+
+ def __str__(self) -> str:
+ return self.name
+
+
+@dataclassabc(frozen=True)
+class MatParam(Param):
+ """ Matrix parameter. TODO: desc. """
+ name: str
+ shape: Shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ if self not in state.parameters:
+ raise MissingParameters("Cannot construct representation because "
+ f"value for parameter {self} was not given.")
+
+ # FIXME: add conversion to scalar variables
+ # Ignore typecheck because dataclassabc has not type stub
+ return MatConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
+
+ def __str__(self) -> str:
+ return self.name
+
+
+# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓
+# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┗━┓
+# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸┗━┛
+
+
+@dataclassabc(frozen=True)
+class PolyVar(Var):
+ """ Variable TODO: desc """
+ name: str # overloads Leaf.name
+ shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+ idx = PolyVarIndex.from_var(self, state), # important comma!
+ r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
+ return r, state
+
+
+@dataclassabc(frozen=True)
+class PolyConst(Const):
+ """ Constant TODO: desc """
+ value: Number # overloads Const.value
+ name: str = "" # overloads Leaf.name
+ shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+ r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
+ return r, state
+
+
+@dataclassabc(frozen=True)
+class PolyParam(Param):
+ """ Polynomial parameter TODO: desc """
+ name: str # overloads Leaf.name
+ shape: Shape = Shape.scalar() # overloads PolyExpr.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ if self not in state.parameters:
+ raise MissingParameters("Cannot construct representation because "
+ f"value for parameter {self} was not given.")
+
+ return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
+
+
+# ┏━┓┏━┓┏━╸┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓
+# ┃ ┃┣━┛┣╸ ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┗━┓
+# ┗━┛╹ ┗━╸╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹┗━┛
+
+
+@dataclass
+class WithOps:
+ """ Monadic wrapper around :py:class:`mdpoly.abc.Expr` that adds operator
+ overloading and operations to expression objects. """
+ expr: Expr
+
+ # -- Monadic operations --
+
+ @staticmethod
+ def map(fn: Callable[[Expr], Expr]) -> Callable[[WithOps], WithOps]:
+ """ Map in the functional programming sense.
+
+ Converts a function like `(Expr -> Expr)` into
+ `(WithOps<Expr> -> WithOps<Expr>)`.
+ """
+ @wraps(fn)
+ def wrapper(e: WithOps) -> WithOps:
+ return WithOps(expr=fn(e.expr))
+ return wrapper
+
+ @staticmethod
+ def zip(fn: Callable[..., Expr]) -> Callable[..., WithOps]:
+ """ Zip, in the functional programming sense.
+
+ Converts a function like `(Expr -> Expr -> ... -> Expr)` into
+ `(WithOps<Expr> -> WithOps<Expr> -> ... -> WithOps<Expr>)`.
+ """
+ @wraps(fn)
+ def wrapper(*args: WithOps) -> WithOps:
+ return WithOps(expr=fn(*(arg.expr for arg in args)))
+ return wrapper
+
+ @staticmethod
+ def bind(fn: Callable[[Expr], WithOps]) -> Callable[[WithOps], WithOps]:
+ """ Bind in the functional programming sense. Also sometimes known as
+ flatmap.
+
+ Converts a funciton like `(Expr -> WithOps)` into `(WithOps -> WithOps)`
+ so that it can be composed with other functions.
+ """
+ @wraps(fn)
+ def wrapper(arg: WithOps) -> WithOps:
+ return fn(arg.expr)
+ return wrapper
+
+ # -- Operator overloading ---
+
+ def __add__(self, other: WithOps | Number) -> WithOps:
+ if isinstance(other, Number):
+ other = WithOps(expr=PolyConst(value=other))
+
+ return add(self, other)
+
+ def __radd__(self, other: WithOps | Number) -> WithOps:
+ if isinstance(other, Number):
+ other = WithOps(expr=PolyConst(value=other))
+
+ return add(other, self)
+
+ def __sub__(self, other: Self) -> Self:
+ raise NotImplementedError
+
+ def __rsub__(self, other: Self) -> Self:
+ raise NotImplementedError
+
+ def __neg__(self) -> Self:
+ raise NotImplementedError
+
+ def __mul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rmul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __pow__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __matmul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rmatmul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __truediv__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rtruediv__(self, other: Any) -> Self:
+ raise NotImplementedError
diff --git a/mdpoly/expressions/__init__.py b/mdpoly/expressions/__init__.py
deleted file mode 100644
index 2efe342..0000000
--- a/mdpoly/expressions/__init__.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from __future__ import annotations
-from typing import TYPE_CHECKING
-
-from abc import abstractmethod
-from typing import Type
-from dataclasses import field
-from dataclassabc import dataclassabc
-
-from ..abc import Expr, Nothing
-
-if TYPE_CHECKING:
- from ..abc import ReprT
- from ..state import State
-
-
-@dataclassabc(eq=False)
-class BinaryOp(Expr):
- left: Expr = field()
- right: Expr = field()
-
-
-@dataclassabc(eq=False)
-class UnaryOp(Expr):
- right: Expr = field()
-
- @property
- def inner(self) -> Expr:
- """ Inner expression on which the operator is acting, alias for right. """
- return self.right
-
- @property
- def left(self) -> Expr:
- return Nothing()
-
- @left.setter
- def left(self, left) -> None:
- if not isinstance(left, Nothing):
- raise ValueError("Cannot set left of left-acting unary operator "
- "to something that is not of type Nothing.")
-
-
-class Reducible(Expr):
- """ Reducible Expression.
-
- Algebraic expression that can be written in terms of other (existing)
- expression objects, i.e. can be reduced to another expression made of
- simpler operations. For example subtraction can be written in term of
- addition and multiplication.
- """
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
- return self.reduce().to_repr(repr_type, state)
-
- @abstractmethod
- def reduce(self) -> Expr:
- """ Reduce the expression to its basic elements """
-
-
-# TODO: review this idea before implementing
-# class Simplifiable(Expr):
-# """ Simplifiable Expression. """
-#
-# @abstractmethod
-# def simplify(self) -> Expr:
-# """ Simplify the expression """
diff --git a/mdpoly/expressions/matrix.py b/mdpoly/expressions/matrix.py
deleted file mode 100644
index cd38383..0000000
--- a/mdpoly/expressions/matrix.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from __future__ import annotations
-from typing import TYPE_CHECKING
-
-from typing import Type, TypeVar, Iterable, Sequence
-from dataclassabc import dataclassabc
-
-from .poly import PolyVar, PolyConst
-from ..abc import Expr, Var, Const, Param, Algebra
-from ..index import MatrixIndex, PolyIndex, PolyVarIndex
-from ..errors import MissingParameters
-
-if TYPE_CHECKING:
- from ..abc import ReprT
- from ..index import Shape, Number
- from ..state import State
-
-
-class MatrixExpr(Expr):
- r""" Expression with the algebraic properties of a matrix ring and / or
- module (depending on the shape).
-
- We denote with :math:`R` a polynomial ring.
-
- - If the shape is square, like :math:`(n, n)` then this is :math:`M_n(R)`
- the ring of matrices over :math:`R`.
-
- - If the shape is something else like a row or column (:math:`(m, p)`,
- :math:`(1, n)` or :math:`(n, 1)`) this is a module, i.e. an algebra with
- addition and scalar multiplication, where the "scalars" come from
- :math:`R`.
-
- Furthermore some operators that are usually expected from matrices are
- already included (eg. transposition).
- """
-
- @property
- def algebra(self) -> Algebra:
- return Algebra.matrix_ring
-
- def __add__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
- return operations.add.MatAdd(self, other)
-
- def __sub__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
- return operations.add.MatSub(self, other)
-
- def __rsub__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
- return operations.add.MatSub(other, self)
-
- def __neg__(self):
- # FIXME: Create PolyNeg?
- return operations.mul.MatScalarMul(PolyConst(-1), self)
-
- def __mul__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
-
- # TODO: case distiction based on shapes
- return operations.mul.MatScalarMul(other, self)
-
-
- def __rmul__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
- return operations.mul.MatScalarMul(other, self)
-
- def __matmul__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
- return operations.MatMul(self, other)
-
- def __rmatmul(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, MatConst, other)
- return operations.mul.MatMul(other, self)
-
- def __truediv__(self, scalar):
- scalar = self._wrap_if_constant(scalar)
- raise NotImplementedError
-
- def transpose(self) -> MatrixExpr:
- """ Matrix transposition. """
- raise NotImplementedError
- return operations.transpose.MatTranspose(self)
-
- @property
- def T(self) -> MatrixExpr:
- """ Shorthand for :py:meth:`mdpoly.expressions.matrix.MatrixExpr.transpose`. """
- return self.transpose()
-
- def to_scalar(self, scalar_type: type):
- """ Convert to a scalar expression. """
- raise NotImplementedError
-
-
-@dataclassabc(frozen=True)
-class MatConst(Const, MatrixExpr):
- """ Matrix constant. TODO: desc. """
- value: Sequence[Sequence[Number]] # Row major, overloads Const.value
- shape: Shape # overloads Expr.shape
- name: str = "" # overloads Leaf.name
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
-
- for i, row in enumerate(self.value):
- for j, val in enumerate(row):
- r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val)
-
- return r, state
-
-
- def __str__(self) -> str:
- if not self.name:
- return repr(self.value)
-
- return self.name
-
-
-T = TypeVar("T", bound=Var)
-
-@dataclassabc(frozen=True)
-class MatVar(Var, MatrixExpr):
- """ Matrix of polynomial variables. TODO: desc """
- name: str # overloads Leaf.name
- shape: Shape # overloads Expr.shape
-
- # TODO: review this API, can be moved elsewhere?
- def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
- for row in range(self.shape.rows):
- for col in range(self.shape.cols):
- var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg]
- entry = MatrixIndex(row, col)
-
- yield entry, var
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
-
- # FIXME: do not hardcode scalar type
- for entry, var in self.to_scalars(PolyVar):
- idx = PolyVarIndex.from_var(var, state), # important comma!
- r.set(entry, PolyIndex(idx), 1)
-
- return r, state
-
- def __str__(self) -> str:
- return self.name
-
-
-@dataclassabc(frozen=True)
-class MatParam(Param, MatrixExpr):
- """ Matrix parameter. TODO: desc. """
- name: str
- shape: Shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- if self not in state.parameters:
- raise MissingParameters("Cannot construct representation because "
- f"value for parameter {self} was not given.")
-
- # FIXME: add conversion to scalar variables
- # Ignore typecheck because dataclassabc has not type stub
- return MatConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
-
- def __str__(self) -> str:
- return self.name
diff --git a/mdpoly/expressions/poly.py b/mdpoly/expressions/poly.py
deleted file mode 100644
index 623e89e..0000000
--- a/mdpoly/expressions/poly.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from __future__ import annotations
-from typing import TYPE_CHECKING
-
-from typing import Type, Iterable, cast
-from functools import reduce
-from itertools import chain, combinations_with_replacement
-from dataclassabc import dataclassabc
-from operator import mul as opmul
-
-from ..abc import Expr, Var, Const, Param, Algebra, ReprT
-from ..index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
-from ..errors import AlgebraicError, MissingParameters, InvalidShape
-
-from .. import operations
-
-if TYPE_CHECKING:
- from ..index import Number
- from ..state import State
-
-
-class PolyExpr(Expr):
- r""" Expression with the algebraic behaviour of a polynomial ring.
-
- This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the
- polynomials are scalars.
- """
-
- # -- Properties ---
-
- @property
- def algebra(self) -> Algebra:
- return Algebra.poly_ring
-
- @property
- def shape(self) -> Shape:
- """ See :py:meth:`mdpoly.abc.Expr.shape`. """
- if self.left.shape != self.right.shape:
- raise InvalidShape(f"Cannot perform operation {repr(self)} with "
- f"shapes {self.left.shape} and {self.right.shape}.")
- return self.left.shape
-
- # --- Helper methods for construction ---
-
- @classmethod
- def from_powers(cls, variable: PolyExpr, exponents: Iterable[int]) -> PolyExpr:
- """ Given a variable, say :math:`x`, and a list of exponents ``[0, 3, 5]``
- returns a polynomial :math:`x^5 + x^3 + 1`.
- """
- nonzero_exponents = filter(lambda e: e != 0, exponents)
- # Ignore typecheck because dataclassabc has no typing stub
- constant_term = PolyConst(1 if 0 in exponents else 0) # type: ignore[call-arg]
- # FIXME: remove annoying 0
- return sum((variable ** e for e in nonzero_exponents), constant_term)
-
-
- @classmethod
- def make_combinations(cls, variables: Iterable[PolyVar], max_degree: int) -> Iterable[PolyExpr]:
- """ Make combinations of terms.
-
- For example given :math:`x, y` and *max_degree* of 2 generates
- :math:`x^2, xy, x, y^2, y, 1`.
- """
- # Ignore typecheck because dataclassabc has no typing stub
- vars_and_const = chain(variables, (PolyConst(1),)) # type: ignore[call-arg]
- for comb in combinations_with_replacement(vars_and_const, max_degree):
- yield cast(PolyExpr, reduce(opmul, comb))
-
- # --- Operator Overloading ---
-
- def __add__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, PolyConst, other)
- return operations.add.PolyAdd(self, other)
-
- def __radd__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, PolyConst, other)
- return operations.add.PolyAdd(other, self)
-
- def __sub__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, PolyConst, other)
- return operations.add.PolySub(self, other)
-
- def __rsub__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, PolyConst, other)
- return operations.add.PolyAdd(other, self)
-
- def __neg__(self):
- # FIXME: Create PolyNeg?
- return operations.mul.PolyMul(PolyConst(-1), self)
-
- def __mul__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, PolyConst, other)
- return operations.mul.PolyMul(self, other)
-
- def __rmul__(self, other):
- self._assert_same_algebra(self, other)
- other = self._wrap(Number, PolyConst, other)
- return operations.mul.PolyMul(other, self)
-
- def __matmul__(self, other):
- raise AlgebraicError("Cannot perform matrix multiplication in polynomial ring (they are scalars).")
-
- def __rmatmul__(self, other):
- self.__rmatmul__(other)
-
- def __truediv__(self, other):
- other = self._wrap(Number, PolyConst, other)
- if not isinstance(other, PolyConst | PolyParam):
- raise AlgebraicError("Cannot divide by variables in polynomial ring.")
-
- return operations.mul.PolyMul(self, other)
-
- def __rtruediv__(self, other):
- raise AlgebraicError("Cannot perform right division in polynomial ring.")
-
- def __pow__(self, other):
- other = self._wrap(Number, PolyConst, other)
- if not isinstance(other, PolyConst | PolyParam):
- raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in "
- "polynomial ring. Only constants and parameters are allowed.")
- return operations.exp.PolyExp(left=self, right=other)
-
- # -- Other mathematical operations ---
-
- def diff(self, wrt: PolyVar) -> operations.derivative.PolyPartialDiff:
- return operations.derivative.PolyPartialDiff(right=self, wrt=wrt)
-
-
-@dataclassabc(frozen=True)
-class PolyVar(Var, PolyExpr):
- """ Variable TODO: desc """
- name: str # overloads Leaf.name
- shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
- idx = PolyVarIndex.from_var(self, state), # important comma!
- r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
- return r, state
-
-
-@dataclassabc(frozen=True)
-class PolyConst(Const, PolyExpr):
- """ Constant TODO: desc """
- value: Number # overloads Const.value
- name: str = "" # overloads Leaf.name
- shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
- r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
- return r, state
-
-
-@dataclassabc(frozen=True)
-class PolyParam(Param, PolyExpr):
- """ Polynomial parameter TODO: desc """
- name: str # overloads Leaf.name
- shape: Shape = Shape.scalar() # overloads PolyExpr.shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- if self not in state.parameters:
- raise MissingParameters("Cannot construct representation because "
- f"value for parameter {self} was not given.")
-
- return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
diff --git a/mdpoly/operations/__init__.py b/mdpoly/operations/__init__.py
index e69de29..752358c 100644
--- a/mdpoly/operations/__init__.py
+++ b/mdpoly/operations/__init__.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from typing import Type
+from abc import abstractmethod
+from dataclasses import field
+from dataclassabc import dataclassabc
+
+from ..abc import ReprT, Expr, Nothing
+
+if TYPE_CHECKING:
+ from ..state import State
+
+
+@dataclassabc(eq=False)
+class BinaryOp(Expr):
+ """ Binary operator.
+
+ TODO: desc
+ """
+ left: Expr = field()
+ right: Expr = field()
+
+
+@dataclassabc(eq=False)
+class UnaryOp(Expr):
+ """ Unary operator.
+
+ TODO: desc
+ """
+ right: Expr = field()
+
+ @property
+ def inner(self) -> Expr:
+ """ Inner expression on which the operator is acting, alias for right. """
+ return self.right
+
+ @property
+ def left(self) -> Expr:
+ return Nothing()
+
+ @left.setter
+ def left(self, left) -> None:
+ if not isinstance(left, Nothing):
+ raise ValueError("Cannot set left of left-acting unary operator "
+ "to something that is not of type Nothing.")
+
+
+class Reducible(Expr):
+ """ Reducible Expression.
+
+ Algebraic expression that can be written in terms of other (existing)
+ expression objects, i.e. can be reduced to another expression made of
+ simpler operations. For example subtraction can be written in term of
+ addition and multiplication.
+ """
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ return self.reduce().to_repr(repr_type, state)
+
+ @abstractmethod
+ def reduce(self) -> Expr:
+ """ Reduce the expression to its basic elements """
+
+
+# TODO: review this idea before implementing
+# class Simplifiable(Expr):
+# """ Simplifiable Expression. """
+#
+# @abstractmethod
+# def simplify(self) -> Expr:
+# """ Simplify the expression """
diff --git a/mdpoly/operations/add.py b/mdpoly/operations/add.py
index b0174ef..30ad3b4 100644
--- a/mdpoly/operations/add.py
+++ b/mdpoly/operations/add.py
@@ -7,9 +7,7 @@ from dataclassabc import dataclassabc
from ..abc import Expr
from ..errors import AlgebraicError
-from ..expressions import BinaryOp, Reducible
-from ..expressions.matrix import MatrixExpr
-from ..expressions.poly import PolyExpr
+from . import BinaryOp, Reducible
if TYPE_CHECKING:
from ..abc import ReprT
@@ -18,7 +16,7 @@ if TYPE_CHECKING:
@dataclassabc(eq=False)
-class MatAdd(BinaryOp, MatrixExpr):
+class MatAdd(BinaryOp):
""" Addition operator between matrices. """
@property
@@ -52,7 +50,7 @@ class MatAdd(BinaryOp, MatrixExpr):
@dataclassabc(eq=False)
-class MatSub(BinaryOp, MatrixExpr, Reducible):
+class MatSub(BinaryOp, Reducible):
""" Subtraction operator between matrices. """
@property
@@ -69,13 +67,3 @@ class MatSub(BinaryOp, MatrixExpr, Reducible):
def __str__(self) -> str:
return f"({self.left} - {self.right})"
-
-
-@dataclassabc(eq=False)
-class PolyAdd(MatAdd, PolyExpr):
- ...
-
-
-@dataclassabc(eq=False)
-class PolySub(MatSub, PolyExpr):
- ...
diff --git a/mdpoly/operations/derivative.py b/mdpoly/operations/derivative.py
index 9a74941..a771ca7 100644
--- a/mdpoly/operations/derivative.py
+++ b/mdpoly/operations/derivative.py
@@ -4,12 +4,11 @@ from typing import TYPE_CHECKING
from typing import Type, Self, cast
from dataclassabc import dataclassabc
-from ..abc import Expr
+from ..abc import Expr, Var
from ..errors import AlgebraicError
from ..index import Shape, MatrixIndex, PolyIndex
-from ..expressions import UnaryOp
-from ..expressions.poly import PolyExpr, PolyVar
+from . import UnaryOp
if TYPE_CHECKING:
from ..abc import ReprT
@@ -20,9 +19,9 @@ if TYPE_CHECKING:
@dataclassabc(eq=False)
-class PolyPartialDiff(UnaryOp, PolyExpr):
+class PolyPartialDiff(UnaryOp):
""" Partial differentiation of scalar polynomials. """
- wrt: PolyVar
+ wrt: Var
@property
def shape(self) -> Shape:
@@ -50,13 +49,12 @@ class PolyPartialDiff(UnaryOp, PolyExpr):
def __str__(self) -> str:
return f"(∂_{self.wrt} {self.inner})"
- def replace(self, old: PolyExpr, new: PolyExpr) -> Self:
+ def replace(self, old: Expr, new: Expr) -> Self:
""" Overloads :py:meth:`mdpoly.abc.Expr.replace` """
if self.wrt == old:
- if not isinstance(new, PolyVar):
+ if not isinstance(new, Var):
# FIXME: implement chain rule
raise AlgebraicError(f"Cannot take a derivative with respect to {new}.")
- self.wrt = cast(PolyVar, new)
-
+ self.wrt = cast(Var, new)
return cast(Self, Expr.replace(self, old, new))
diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py
index 177fd22..0f57fcb 100644
--- a/mdpoly/operations/exp.py
+++ b/mdpoly/operations/exp.py
@@ -9,20 +9,18 @@ from dataclasses import dataclass
from ..abc import Expr, Const
from ..errors import AlgebraicError
-from ..expressions import BinaryOp, Reducible
-from ..expressions.matrix import MatrixExpr
-from ..expression.poly import PolyExpr
+from . import BinaryOp, Reducible
# TODO: implement matrix exponential, use caley-hamilton thm magic
@dataclass(eq=False)
-class MatExp(BinaryOp, MatrixExpr, Reducible):
+class MatExp(BinaryOp, Reducible):
def __init__(self):
raise NotImplementedError
@dataclass(eq=False)
-class PolyExp(BinaryOp, PolyExpr, Reducible):
+class PolyExp(BinaryOp, Reducible):
""" Exponentiation operator between scalar polynomials. """
@property # type: ignore[override]
diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py
index 2a8f749..7e75d26 100644
--- a/mdpoly/operations/mul.py
+++ b/mdpoly/operations/mul.py
@@ -10,9 +10,7 @@ from ..index import Shape
from ..errors import AlgebraicError, InvalidShape
from ..index import MatrixIndex, PolyIndex
-from ..expressions import BinaryOp, Reducible
-from ..expressions.matrix import MatrixExpr
-from ..expression.poly import PolyExpr
+from . import BinaryOp, Reducible
if TYPE_CHECKING:
from ..abc import ReprT
@@ -25,7 +23,7 @@ if TYPE_CHECKING:
@dataclassabc
-class MatElemMul(BinaryOp, MatrixExpr):
+class MatElemMul(BinaryOp):
""" Elementwise Matrix Multiplication. """
@property
@@ -64,7 +62,7 @@ class MatElemMul(BinaryOp, MatrixExpr):
@dataclassabc
-class MatScalarMul(BinaryOp, MatrixExpr):
+class MatScalarMul(BinaryOp):
""" Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix
on the right. """
@@ -101,7 +99,7 @@ class MatScalarMul(BinaryOp, MatrixExpr):
@dataclass(eq=False)
-class MatMul(BinaryOp, MatrixExpr):
+class MatMul(BinaryOp):
""" Matrix Multiplication. """
@property
@@ -134,7 +132,7 @@ class MatMul(BinaryOp, MatrixExpr):
@dataclass(eq=False)
-class MatDotProd(BinaryOp, MatrixExpr, Reducible):
+class MatDotProd(BinaryOp, Reducible):
""" Dot product. """
@property
@@ -157,7 +155,7 @@ class MatDotProd(BinaryOp, MatrixExpr, Reducible):
@dataclass(eq=False)
-class PolyMul(BinaryOp, PolyExpr):
+class PolyMul(BinaryOp):
""" Multiplication operator between scalar polynomials. """
def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
diff --git a/mdpoly/operations/transpose.py b/mdpoly/operations/transpose.py
index 750d0c0..0e4228b 100644
--- a/mdpoly/operations/transpose.py
+++ b/mdpoly/operations/transpose.py
@@ -4,14 +4,13 @@ from typing import TYPE_CHECKING
from dataclasses import dataclass
from ..expressions import UnaryOp
-from ..expressions.matrix import MatrixExpr
if TYPE_CHECKING:
from ..index import Shape
@dataclass(eq=False)
-class MatTranspose(UnaryOp, MatrixExpr):
+class MatTranspose(UnaryOp):
""" Matrix transposition """
@property