aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/__init__.py22
-rw-r--r--mdpoly/abc.py93
-rw-r--r--mdpoly/algebra.py437
-rw-r--r--mdpoly/leaves.py6
4 files changed, 265 insertions, 293 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index 42a800d..fddd62b 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -5,8 +5,8 @@
from .abc import (Shape as _Shape)
-from .algebra import (PolyRingAlgebra as _PolyRingAlgebra,
- MatrixAlgebra as _MatrixAlgebra)
+from .algebra import (PolyRingExpr as _PolyRingExpr,
+ MatrixExpr as _MatrixExpr)
from .leaves import (Const as _Const,
Var as _Var,
@@ -14,7 +14,7 @@ from .leaves import (Const as _Const,
from .state import State as _State
-from typing import Self, Sequence
+from typing import Self, Iterable
# ┏━╸╻ ╻┏━┓┏━┓┏━┓╺┳╸ ┏━┓┏━┓ ╻┏━┓
@@ -32,34 +32,34 @@ State = _State
# FIXME: move out of this file
class _FromHelpers:
@classmethod
- def from_names(cls, comma_separated_names: str, strip: bool =True) -> Sequence[Self]:
+ def from_names(cls, comma_separated_names: str, strip: bool =True) -> Iterable[Self]:
""" Generate scalar variables from comma separated list of names """
- names = comma_separated_names.split(",")
+ names: Iterable = comma_separated_names.split(",")
if strip:
names = map(str.strip, names)
yield from map(cls, names)
-class Constant(_Const, _PolyRingAlgebra, _FromHelpers):
+class Constant(_Const, _PolyRingExpr, _FromHelpers):
""" Constant values """
-class Variable(_Var, _PolyRingAlgebra, _FromHelpers):
+class Variable(_Var, _PolyRingExpr, _FromHelpers):
""" Polynomial Variable """
-class Parameter(_Param, _PolyRingAlgebra):
+class Parameter(_Param, _PolyRingExpr):
""" Parameter that can be substituted """
-class MatrixConstant(_Const, _MatrixAlgebra):
+class MatrixConstant(_Const, _PolyRingExpr):
""" Matrix constant """
-class MatrixVariable(_Var, _MatrixAlgebra):
+class MatrixVariable(_Var, _MatrixExpr):
""" Matrix Polynomial Variable """
-class MatrixParameter(_Param, _PolyRingAlgebra):
+class MatrixParameter(_Param, _MatrixExpr):
""" Matrix Parameter """
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 6637db5..5d43055 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -1,36 +1,36 @@
""" Abstract Base Classes of MDPoly """
+from __future__ import annotations
from .index import Number, Shape, MatrixIndex, PolyIndex
from .constants import NUMERICS_EPS
from .util import iszero
-from typing import Self, Iterable, Sequence
+from typing import Self, Any, Iterable, Sequence
from enum import Enum, auto
from abc import ABC, abstractmethod
class Algebra(Enum):
""" Types of algebras. """
+ none = auto()
poly_ring = auto()
matrix_ring = auto()
-class Leaf(ABC):
- """ Leaf of the binary tree. """
- name: str
- shape: Shape
-
-
class Expr(ABC):
""" Binary tree to represent a mathematical expression. """
@property
+ def is_leaf(self) -> bool:
+ return False
+
+ @property
@abstractmethod
- def left(self) -> Self | Leaf: ...
+ def left(self) -> Self: ...
@property
@abstractmethod
- def right(self) -> Self | Leaf: ...
+ def right(self) -> Self: ...
@property
@abstractmethod
@@ -48,19 +48,19 @@ class Expr(ABC):
name = self.__class__.__qualname__
return f"{name}(left={self.left}, right={self.right})"
- def children(self) -> Sequence[Self | Leaf]:
+ def children(self) -> Sequence[Expr]:
""" Iterate over the two nodes """
return self.left, self.right
- def leaves(self) -> Iterable[Leaf]:
+ def leaves(self) -> Iterable[Expr]:
""" Returns the leaves of the tree. This is done recursively and in
:math:`\mathcal{O}(n)`."""
- if isinstance(self.left, Leaf):
+ if self.left.is_leaf:
yield self.left
else:
yield from self.left.leaves()
- if isinstance(self.right, Leaf):
+ if self.right.is_leaf:
yield self.right
else:
yield from self.right.leaves()
@@ -109,9 +109,74 @@ class Expr(ABC):
return replace_all(self)
- def __iter__(self) -> Iterable[Self | Leaf]:
+ def __iter__(self) -> Iterable[Expr]:
yield from self.children()
+ # --- Operator overloading ---
+
+ @staticmethod
+ def _wrap(if_type: type, wrapper_type: type, obj):
+ # Do not wrap if is alreay an expression
+ if isinstance(obj, Expr):
+ return obj
+
+ if not isinstance(obj, if_type):
+ raise TypeError
+
+ return wrapper_type(obj)
+
+ def __add__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __radd__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __sub__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rsub__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __neg__(self) -> Self:
+ raise NotImplementedError
+
+ def __mul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rmul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __pow__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __matmul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rmatmul__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __truediv__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+ def __rtruediv__(self, other: Any) -> Self:
+ raise NotImplementedError
+
+
+class Leaf(Expr):
+ """ Leaf of the binary tree. """
+ name: str
+
+ @property
+ def is_leaf(self):
+ return True
+
+ @property
+ def left(self):
+ ...
+
+ @property
+ def right(self):
+ ...
class Rel(ABC):
""" Relation between two expressions. """
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index a07c4d6..aa12ae3 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -1,79 +1,60 @@
""" Algebraic Structures for Expressions """
+from __future__ import annotations
-from .abc import Leaf, Expr, Repr
+from .abc import Algebra, Expr, Repr
from .leaves import Nothing, Const, Param, Var
from .errors import AlgebraicError, InvalidShape
from .state import State
-from .index import Shape, MatrixIndex, PolyIndex
+from .index import Shape, MatrixIndex, PolyIndex, Number
from .representations import HasRepr
-from typing import Protocol, Self, Any, runtime_checkable
-from functools import wraps, reduce, cached_property
+from typing import cast, Protocol, Self, Sequence
+from functools import reduce
from itertools import product
-from enum import Enum
from abc import abstractmethod
import operator
-@runtime_checkable
-class AlgebraicStructure(Protocol):
- """ Provides methods to enforce algebraic closure of operations.
+# To make the type checker happy
+class ExprWithRepr(Expr, HasRepr):
- Internally, this is done by adding a class variable ``_algebra`` of type
- :py:class:`mdpoly.algebra.AlgebraicStructure.Algebra`, and checking that
- they match in the methods for operator overloading.
- """
-
- class Algebra(Enum):
- """ Types of algebras. """
- poly_ring = 0
- matrix_ring = 1
-
- _algebra: Algebra
- _parameter: type
- _constant: type
+ def children(self) -> Sequence[Self]: # type: ignore[override]
+ return cast(Sequence[Self], Expr.children(self))
- @classmethod
- def _is_constant(cls, other: Self) -> bool:
- return isinstance(other, cls._constant)
- @classmethod
- def _is_parameter(cls, other: Self) -> bool:
- return isinstance(other, cls._parameter)
+class BinaryOp(ExprWithRepr):
+ def __init__(self, left, right):
+ self._left = left
+ self._right = right
- @classmethod
- def _is_const_or_param(cls, other: Self) -> bool:
- return cls._is_constant(other) or cls._is_parameter(other)
+ @property
+ def left(self) -> ExprWithRepr: # type: ignore[override]
+ return self._left
- @classmethod
- def _assert_same_algebra(cls, other: Self) -> None:
- if not cls._algebra == other._algebra:
- if not cls._is_constant(other):
- raise AlgebraicError("Cannot perform operation between types from "
- f"different algebraic structures {cls} ({cls._algebra}) "
- f"and {type(other)} ({other._algebra})")
+ @property
+ def right(self) -> ExprWithRepr: # type: ignore[override]
+ return self._right
- @classmethod
- def _wrap_if_constant(cls, other: Any):
- if isinstance(other, AlgebraicStructure):
- cls._assert_same_algebra(other)
- return other
- if isinstance(other, cls._constant):
- return other
+class UnaryOp(ExprWithRepr):
+ def __init__(self, inner):
+ self._inner = inner
- return cls._constant(name="", value=other)
+ @property
+ def inner(self) -> Expr:
+ return self._inner
+
+ @property
+ def left(self) -> ExprWithRepr: # type: ignore[override]
+ return self._inner
- # FIXME: add shape rules in algebra
- # @classmethod
- # def _assert_same_shape(cls: T, other: T) -> None:
- # if not cls.shape == other.shape:
- # raise AlgebraicError("Cannot perform algebraic operations between "
- # f"objects with different shapes {cls.shape} and {other.shape}")
+ @property
+ def right(self) -> ExprWithRepr: # type: ignore[override]
+ return Nothing() # type: ignore
-class ReducibleExpr(HasRepr, Protocol):
+class Reducible(HasRepr, Protocol):
""" Reducible Expression
Algebraic expression that can be written in terms of other (existing)
@@ -91,80 +72,83 @@ class ReducibleExpr(HasRepr, Protocol):
""" Reduce the expression to its basic elements """
-# FIXME: The type checker really hates this trick, what is a better solution?
-def binary_operator(left_type: AlgebraicStructure, right_type: AlgebraicStructure):
- """ Class decorator that adds constructor for binary operations of Expr.
-
- Classes that inherit :py:class:`mdpoly.abc.Expr` take values for *left* and
- *right* to represent the operands. This binary operator specifies the
- algebra (:py:class:`mdpoly.algebra.AlgebraicStructure`) that *left* and
- *right* have to respect in order for the operation to be correct.
- Concretely, this is to raise and exception for eg. when a scalar is added
- to a matrix, etc.
- """
- # TODO: add right_shape and left_shape for matrices
- def decorator(cls: Expr) -> Expr:
- init_cls = cls.__init__
- @wraps(cls.__init__)
- def new_init_cls(self, left, right, *args, **kwargs):
- init_cls(self, *args, **kwargs)
- # Wrong algebra
- if not isinstance(left, left_type) or not isinstance(right, right_type):
- # None of the two is a Leaf. This is a workaround because
- # adding the algebra to the base Leaf types Const, Var, etc. is
- # not possible without having circular imports. For the exported types
- # Constant, Variable, etc. this is not a problem.
- if not isinstance(left, Leaf) and not isinstance(right, Leaf):
- raise AlgebraicError(
- "Cannot perform operation between types from "
- f"different algebraic structures {type(left)} ({left._algebra}) "
- f"and {type(right)} ({right._algebra})")
-
- self.left, self.right = left, right
- cls.__init__ = new_init_cls
- return cls
- return decorator
-
-
-# FIXME: same as binary_operator
-def unary_operator(inner_type: AlgebraicStructure):
- """ Class decorator that adss constructor for unary operations of Expr
-
- This is the special case of the binary operator, where only *left* is used
- (because it is assumed that the operator acts from the left. The field for
- *right* contains a placeholder :py:class:`mdpoly.leaves.Nothing`,
- furthermore a new field *inner* is added as alias to *left*.
-
- See also :py:func:`mdpoly.algebra.binary_operator`.
- """
- def decorator(cls: Expr) -> Expr:
- """ Class decorator to modify constructor of unary operators. """
- init_cls = cls.__init__
- @wraps(cls.__init__)
- def new_init_cls(self, left, *args, **kwargs):
- init_cls(self, *args, **kwargs)
- self.left, self.right = left, Nothing
-
- def inner(self):
- return self.left
+# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓
+# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
+# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
- cls.__init__ = new_init_cls
- cls.inner = property(inner)
- return cls
- return decorator
+class PolyRingExpr(Expr):
+ r""" Endows with the algebraic structure of a polynomial ring.
+ This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the
+ polynomials are scalars.
+ """
-class Add(Expr, HasRepr):
- """ Generic addition (no type check) """
+ @property
+ def algebra(self):
+ return
@property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
if self.left.shape != self.right.shape:
- raise InvalidShape(f"Cannot add {self.left} and {self.right} with "
+ raise InvalidShape(f"Cannot perform operation {repr(self)} with "
f"shapes {self.left.shape} and {self.right.shape}.")
return self.left.shape
+ def __add__(self, other):
+ other = self._wrap(Number, Const, other)
+ return PolyAdd(self, other)
+
+ def __radd__(self, other):
+ other = self._wrap(Number, Const, other)
+ return PolyAdd(other, self)
+
+ def __sub__(self, other):
+ other = self._wrap(Number, Const, other)
+ return PolySub(self, other)
+
+ def __rsub__(self, other):
+ other = self._wrap(Number, Const, other)
+ return PolyAdd(other, self)
+
+ def __neg__(self):
+ return PolyMul(self._constant(-1), self)
+
+ def __mul__(self, other):
+ other = self._wrap(Number, Const, other)
+ return PolyMul(self, other)
+
+ def __rmul__(self, other):
+ other = self._wrap(Number, Const, other)
+ return PolyMul(other, self)
+
+ def __matmul__(self, other):
+ raise AlgebraicError("Cannot perform matrix multiplication in polynomial ring (they are scalars).")
+
+ def __rmatmul__(self, other):
+ self.__rmatmul__(other)
+
+ def __truediv__(self, other):
+ other = self._wrap(Number, Const, other)
+ if not self._is_const_or_param(other):
+ raise AlgebraicError("Cannot divide by variables in polynomial ring.")
+
+ return PolyMul(self, other)
+
+ def __rtruediv__(self, other):
+ raise AlgebraicError("Cannot perform right division in polynomial ring.")
+
+ def __pow__(self, other):
+ other = self._wrap(Number, Const, other)
+ if not isinstance(other, Const | Param):
+ raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in "
+ "polynomial ring. Only constants and parameters are allowed.")
+ return PolyExp(self, other)
+
+
+class PolyAdd(BinaryOp, PolyRingExpr):
+ """ Addition operator between scalar polynomials. """
+
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
# Make a new empty representation
@@ -186,34 +170,19 @@ class Add(Expr, HasRepr):
return f"({self.left} + {self.right})"
-class Sub(ReducibleExpr):
- """ Generic subtraction operator (no type check) """
-
- @property
- def shape(self) -> Shape:
- """ See :py:meth:`mdpoly.abc.Expr.shape`. """
- if self.left.shape != self.right.shape:
- raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
- return self.left.shape
+class PolySub(BinaryOp, PolyRingExpr, Reducible):
+ """ Subtraction operator between scalar polynomials. """
def reduce(self) -> HasRepr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
- # return self.left + (-1 * self.right)
- return Add(self.left, Mul(Const(value=-1), self.right))
+ return self.left + (-1 * self.right)
def __repr__(self) -> str:
return f"({self.left} - {self.right})"
-class Mul(Expr, HasRepr):
- """ Generic multiplication operator (no type check). """
-
- @property
- def shape(self) -> Shape:
- """ See :py:meth:`mdpoly.abc.Expr.shape`. """
- if self.left.shape != self.right.shape:
- raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
- return self.left.shape
+class PolyMul(BinaryOp, PolyRingExpr):
+ """ Multiplication operator between scalar polynomials. """
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
@@ -241,19 +210,23 @@ class Mul(Expr, HasRepr):
return f"({self.left} * {self.right})"
-class Exp(ReducibleExpr):
+class PolyExp(BinaryOp, PolyRingExpr, Reducible):
""" Generic exponentiation (no type check). """
@property
- def shape(self) -> Shape:
- """ See :py:meth:`mdpoly.abc.Expr.shape`. """
- if self.left.shape != self.right.shape:
- raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
- return self.left.shape
+ def right(self) -> Const: # type: ignore[override]
+ if not isinstance(super().right, Const):
+ raise AlgebraicError(f"Cannot raise {self.left} to {self.right} because"
+ f"{self.right} is not a constant.")
+
+ return cast(Const, super().right)
def reduce(self) -> HasRepr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
var = self.left
+ if not isinstance(self.right.value, int):
+ raise NotImplementedError
+
ntimes = self.right.value - 1
return reduce(operator.mul, (var for _ in range(ntimes)), var)
@@ -261,99 +234,10 @@ class Exp(ReducibleExpr):
return f"({self.left} ** {self.right})"
-# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓
-# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
-# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
-
-class PolyRingAlgebra(AlgebraicStructure, Protocol):
- r""" Endows with the algebraic structure of a polynomial ring.
-
- This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the
- polynomials are scalars.
- """
- _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.poly_ring
- _parameter: type = Param
- _constant: type = Const
-
- def __add__(self, other):
- other = self._wrap_if_constant(other)
- return PolyAdd(self, other)
-
- def __radd__(self, other):
- other = self._wrap_if_constant(other)
- return PolyAdd(other, self)
-
- def __sub__(self, other):
- other = self._wrap_if_constant(other)
- return PolySub(self, other)
-
- def __rsub__(self, other):
- other = self._wrap_if_constant(other)
- return PolyAdd(other, self)
-
- def __neg__(self, other):
- return PolyMul(self._constant(-1), other)
-
- def __mul__(self, other):
- other = self._wrap_if_constant(other)
- return PolyMul(self, other)
-
- def __rmul__(self, other):
- other = self._wrap_if_constant(other)
- return PolyMul(other, self)
-
- def __truediv__(self, other):
- other = self._wrap_if_constant(other)
- if not self._is_const_or_param(other):
- raise AlgebraicError("Cannot divide by variables in polynomial ring.")
-
- return PolyMul(self, other)
-
- def __pow__(self, other):
- other = self._wrap_if_constant(other)
- if not self._is_const_or_param(other):
- raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in "
- "polynomial ring. Only constants and parameters are allowed.")
- return PolyExp(self, other)
-
-
-@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class PolyAdd(Add, PolyRingAlgebra):
- """ Addition operator between scalar polynomials. """
-
-
-@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class PolySub(Sub, PolyRingAlgebra):
- """ Subtraction operator between scalar polynomials. """
-
-
-@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class PolyMul(Mul, PolyRingAlgebra):
- """ Multiplication operator between scalar polynomials. """
-
-
-@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class PolyDiv(ReducibleExpr, PolyRingAlgebra):
- """ Division of scalar polynomial by scalar. """
-
- def reduce(self) -> HasRepr:
- """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
- inverse = self._constant(value=1/self.right.value)
- return PolyMul(inverse, self.left)
-
- def __repr__(self) -> str:
- return f"({self.left} / {self.right})"
-
-
-@binary_operator(PolyRingAlgebra, PolyRingAlgebra)
-class PolyExp(Exp, PolyRingAlgebra):
- """ Exponentiation operator of scalar polynomials. """
-
-
-@unary_operator(PolyRingAlgebra)
-class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
+class PolyPartialDiff(UnaryOp, PolyRingExpr):
""" Partial differentiation of scalar polynomials """
- def __init__(self, with_respect_to: Var):
+ def __init__(self, inner: Expr, with_respect_to: Var):
+ UnaryOp.__init__(self, inner)
self.wrt = with_respect_to
@property
@@ -384,7 +268,7 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
# ╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
-class RationalFieldAlgebra(AlgebraicStructure, Protocol):
+class RationalFieldExpr(Expr):
""" Endows with the algebraic structure of a field of rational functions """
@@ -393,7 +277,7 @@ class RationalFieldAlgebra(AlgebraicStructure, Protocol):
# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹
-class MatrixAlgebra(AlgebraicStructure, Protocol):
+class MatrixExpr(Expr):
r""" Endows with the algebraic structure of a matrix ring and / or module
depending on the shape.
@@ -411,37 +295,36 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
already included (eg. transposition).
"""
- _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.matrix_ring
- # FIXME: consider MatParam or something like that?
- _parameter: type = Param
- _constant: type = Const
+ @property
+ def algebra(self):
+ return Algebra.matrix_algebra
def __add__(self, other):
- other = self._wrap_if_constant(other)
+ other = self._wrap(Number, Const, other)
return MatAdd(self, other)
def __sub__(self, other):
- other = self._wrap_if_constant(other)
+ other = self._wrap(Number, Const, other)
return MatSub(self, other)
def __rsub__(self, other):
- other = self._wrap_if_constant(other)
+ other = self._wrap(Number, Const, other)
return MatSub(other, self)
def __mul__(self, other):
- other = self._wrap_if_constant(other)
+ other = self._wrap(Number, Const, other)
return MatScalarMul(other, self)
def __rmul__(self, other):
- scalar = self._wrap_if_constant(other)
- return MatScalarMul(scalar, self)
+ other = self._wrap(Number, Const, other)
+ return MatScalarMul(other, self)
def __matmul__(self, other):
- other = self._wrap_if_constant(other)
+ other = self._wrap(Number, Const, other)
return MatMul(self, other)
def __rmatmul(self, other):
- other = self._wrap_if_constant(other)
+ other = self._wrap(Number, Const, other)
return MatMul(other, self)
def __truediv__(self, scalar):
@@ -457,31 +340,54 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
""" Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """
return self.transpose()
- # def to_scalar(self, scalar_type: type):
- # """ Convert to a scalar expression. """
- # raise NotImplementedError
+ def to_scalar(self, scalar_type: type):
+ """ Convert to a scalar expression. """
+ raise NotImplementedError
+
+
+class MatAdd(BinaryOp, PolyRingExpr):
+ """ Addition operator between matrices. """
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
+ # Make a new empty representation
+ r = repr_type()
+
+ # Make representations for existing stuff
+ for node in self.children():
+ nrepr, state = node.to_repr(repr_type, state)
+
+ # Process non-zero terms
+ for entry in nrepr.entries():
+ for term in nrepr.terms(entry):
+ s = r.at(entry, term) + nrepr.at(entry, term)
+ r.set(entry, term, s)
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} + {self.right})"
-@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatAdd(Add, MatrixAlgebra):
- """ Matrix Addition. """
+class MatSub(BinaryOp, PolyRingExpr, Reducible):
+ """ Subtraction operator between matrices. """
+ def reduce(self) -> HasRepr:
+ """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
+ return self.left + (-1 * self.right)
-@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatSub(Sub, MatrixAlgebra):
- """ Matrix Subtraction. """
+ def __repr__(self) -> str:
+ return f"({self.left} - {self.right})"
-@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatElemMul(Mul, MatrixAlgebra):
+class MatElemMul(BinaryOp, MatrixExpr):
""" Elementwise Matrix Multiplication. """
def __repr__(self) -> str:
return f"({self.left} .* {self.right})"
-@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatScalarMul(ReducibleExpr, MatrixAlgebra):
+class MatScalarMul(BinaryOp, MatrixExpr):
""" Matrix-Scalar Multiplication. """
@property
@@ -496,8 +402,7 @@ class MatScalarMul(ReducibleExpr, MatrixAlgebra):
raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.")
-@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatDotProd(Expr, MatrixAlgebra):
+class MatDotProd(BinaryOp, MatrixExpr):
""" Dot product. """
@property
@@ -514,8 +419,7 @@ class MatDotProd(Expr, MatrixAlgebra):
return Shape.scalar()
-@binary_operator(MatrixAlgebra, MatrixAlgebra)
-class MatMul(Expr, MatrixAlgebra):
+class MatMul(BinaryOp, MatrixExpr):
""" Matrix Multiplication. """
@property
@@ -532,11 +436,10 @@ class MatMul(Expr, MatrixAlgebra):
return f"({self.left} @ {self.right})"
-@unary_operator(MatrixAlgebra)
-class MatTranspose(Expr, MatrixAlgebra):
+class MatTranspose(UnaryOp, MatrixExpr):
""" Matrix transposition """
@property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
- return Shape(self.inner.cols, self.inner.rows)
+ return Shape(self.inner.shape.cols, self.inner.shape.rows)
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py
index af43435..83324a3 100644
--- a/mdpoly/leaves.py
+++ b/mdpoly/leaves.py
@@ -1,4 +1,4 @@
-from .abc import Leaf, Repr
+from .abc import Algebra, Leaf, Repr
from .index import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex
from .state import State
from .errors import MissingParameters
@@ -14,6 +14,7 @@ class Nothing(Leaf):
"""
name: str = "<nothing>"
shape: Shape = Shape(0, 0)
+ algebra: Algebra = Algebra.none
@dataclass(frozen=True)
@@ -24,6 +25,7 @@ class Const(Leaf, HasRepr):
value: Number
name: str = ""
shape: Shape = Shape.scalar()
+ algebra: Algebra = Algebra.none
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
r = repr_type()
@@ -44,6 +46,7 @@ class Var(Leaf, HasRepr):
"""
name: str
shape: Shape = Shape.scalar()
+ algebra: Algebra = Algebra.none
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
r = repr_type()
@@ -62,6 +65,7 @@ class Param(Leaf, HasRepr):
"""
name: str
shape: Shape = Shape.scalar()
+ algebra: Algebra = Algebra.none
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
if self not in state.parameters: