diff options
-rw-r--r-- | mdpoly/__init__.py | 22 | ||||
-rw-r--r-- | mdpoly/abc.py | 93 | ||||
-rw-r--r-- | mdpoly/algebra.py | 437 | ||||
-rw-r--r-- | mdpoly/leaves.py | 6 |
4 files changed, 265 insertions, 293 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index 42a800d..fddd62b 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -5,8 +5,8 @@ from .abc import (Shape as _Shape) -from .algebra import (PolyRingAlgebra as _PolyRingAlgebra, - MatrixAlgebra as _MatrixAlgebra) +from .algebra import (PolyRingExpr as _PolyRingExpr, + MatrixExpr as _MatrixExpr) from .leaves import (Const as _Const, Var as _Var, @@ -14,7 +14,7 @@ from .leaves import (Const as _Const, from .state import State as _State -from typing import Self, Sequence +from typing import Self, Iterable # ┏━╸╻ ╻┏━┓┏━┓┏━┓╺┳╸ ┏━┓┏━┓ ╻┏━┓ @@ -32,34 +32,34 @@ State = _State # FIXME: move out of this file class _FromHelpers: @classmethod - def from_names(cls, comma_separated_names: str, strip: bool =True) -> Sequence[Self]: + def from_names(cls, comma_separated_names: str, strip: bool =True) -> Iterable[Self]: """ Generate scalar variables from comma separated list of names """ - names = comma_separated_names.split(",") + names: Iterable = comma_separated_names.split(",") if strip: names = map(str.strip, names) yield from map(cls, names) -class Constant(_Const, _PolyRingAlgebra, _FromHelpers): +class Constant(_Const, _PolyRingExpr, _FromHelpers): """ Constant values """ -class Variable(_Var, _PolyRingAlgebra, _FromHelpers): +class Variable(_Var, _PolyRingExpr, _FromHelpers): """ Polynomial Variable """ -class Parameter(_Param, _PolyRingAlgebra): +class Parameter(_Param, _PolyRingExpr): """ Parameter that can be substituted """ -class MatrixConstant(_Const, _MatrixAlgebra): +class MatrixConstant(_Const, _PolyRingExpr): """ Matrix constant """ -class MatrixVariable(_Var, _MatrixAlgebra): +class MatrixVariable(_Var, _MatrixExpr): """ Matrix Polynomial Variable """ -class MatrixParameter(_Param, _PolyRingAlgebra): +class MatrixParameter(_Param, _MatrixExpr): """ Matrix Parameter """ diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 6637db5..5d43055 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -1,36 +1,36 @@ """ Abstract Base Classes of MDPoly """ +from __future__ import annotations from .index import Number, Shape, MatrixIndex, PolyIndex from .constants import NUMERICS_EPS from .util import iszero -from typing import Self, Iterable, Sequence +from typing import Self, Any, Iterable, Sequence from enum import Enum, auto from abc import ABC, abstractmethod class Algebra(Enum): """ Types of algebras. """ + none = auto() poly_ring = auto() matrix_ring = auto() -class Leaf(ABC): - """ Leaf of the binary tree. """ - name: str - shape: Shape - - class Expr(ABC): """ Binary tree to represent a mathematical expression. """ @property + def is_leaf(self) -> bool: + return False + + @property @abstractmethod - def left(self) -> Self | Leaf: ... + def left(self) -> Self: ... @property @abstractmethod - def right(self) -> Self | Leaf: ... + def right(self) -> Self: ... @property @abstractmethod @@ -48,19 +48,19 @@ class Expr(ABC): name = self.__class__.__qualname__ return f"{name}(left={self.left}, right={self.right})" - def children(self) -> Sequence[Self | Leaf]: + def children(self) -> Sequence[Expr]: """ Iterate over the two nodes """ return self.left, self.right - def leaves(self) -> Iterable[Leaf]: + def leaves(self) -> Iterable[Expr]: """ Returns the leaves of the tree. This is done recursively and in :math:`\mathcal{O}(n)`.""" - if isinstance(self.left, Leaf): + if self.left.is_leaf: yield self.left else: yield from self.left.leaves() - if isinstance(self.right, Leaf): + if self.right.is_leaf: yield self.right else: yield from self.right.leaves() @@ -109,9 +109,74 @@ class Expr(ABC): return replace_all(self) - def __iter__(self) -> Iterable[Self | Leaf]: + def __iter__(self) -> Iterable[Expr]: yield from self.children() + # --- Operator overloading --- + + @staticmethod + def _wrap(if_type: type, wrapper_type: type, obj): + # Do not wrap if is alreay an expression + if isinstance(obj, Expr): + return obj + + if not isinstance(obj, if_type): + raise TypeError + + return wrapper_type(obj) + + def __add__(self, other: Any) -> Self: + raise NotImplementedError + + def __radd__(self, other: Any) -> Self: + raise NotImplementedError + + def __sub__(self, other: Any) -> Self: + raise NotImplementedError + + def __rsub__(self, other: Any) -> Self: + raise NotImplementedError + + def __neg__(self) -> Self: + raise NotImplementedError + + def __mul__(self, other: Any) -> Self: + raise NotImplementedError + + def __rmul__(self, other: Any) -> Self: + raise NotImplementedError + + def __pow__(self, other: Any) -> Self: + raise NotImplementedError + + def __matmul__(self, other: Any) -> Self: + raise NotImplementedError + + def __rmatmul__(self, other: Any) -> Self: + raise NotImplementedError + + def __truediv__(self, other: Any) -> Self: + raise NotImplementedError + + def __rtruediv__(self, other: Any) -> Self: + raise NotImplementedError + + +class Leaf(Expr): + """ Leaf of the binary tree. """ + name: str + + @property + def is_leaf(self): + return True + + @property + def left(self): + ... + + @property + def right(self): + ... class Rel(ABC): """ Relation between two expressions. """ diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index a07c4d6..aa12ae3 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -1,79 +1,60 @@ """ Algebraic Structures for Expressions """ +from __future__ import annotations -from .abc import Leaf, Expr, Repr +from .abc import Algebra, Expr, Repr from .leaves import Nothing, Const, Param, Var from .errors import AlgebraicError, InvalidShape from .state import State -from .index import Shape, MatrixIndex, PolyIndex +from .index import Shape, MatrixIndex, PolyIndex, Number from .representations import HasRepr -from typing import Protocol, Self, Any, runtime_checkable -from functools import wraps, reduce, cached_property +from typing import cast, Protocol, Self, Sequence +from functools import reduce from itertools import product -from enum import Enum from abc import abstractmethod import operator -@runtime_checkable -class AlgebraicStructure(Protocol): - """ Provides methods to enforce algebraic closure of operations. +# To make the type checker happy +class ExprWithRepr(Expr, HasRepr): - Internally, this is done by adding a class variable ``_algebra`` of type - :py:class:`mdpoly.algebra.AlgebraicStructure.Algebra`, and checking that - they match in the methods for operator overloading. - """ - - class Algebra(Enum): - """ Types of algebras. """ - poly_ring = 0 - matrix_ring = 1 - - _algebra: Algebra - _parameter: type - _constant: type + def children(self) -> Sequence[Self]: # type: ignore[override] + return cast(Sequence[Self], Expr.children(self)) - @classmethod - def _is_constant(cls, other: Self) -> bool: - return isinstance(other, cls._constant) - @classmethod - def _is_parameter(cls, other: Self) -> bool: - return isinstance(other, cls._parameter) +class BinaryOp(ExprWithRepr): + def __init__(self, left, right): + self._left = left + self._right = right - @classmethod - def _is_const_or_param(cls, other: Self) -> bool: - return cls._is_constant(other) or cls._is_parameter(other) + @property + def left(self) -> ExprWithRepr: # type: ignore[override] + return self._left - @classmethod - def _assert_same_algebra(cls, other: Self) -> None: - if not cls._algebra == other._algebra: - if not cls._is_constant(other): - raise AlgebraicError("Cannot perform operation between types from " - f"different algebraic structures {cls} ({cls._algebra}) " - f"and {type(other)} ({other._algebra})") + @property + def right(self) -> ExprWithRepr: # type: ignore[override] + return self._right - @classmethod - def _wrap_if_constant(cls, other: Any): - if isinstance(other, AlgebraicStructure): - cls._assert_same_algebra(other) - return other - if isinstance(other, cls._constant): - return other +class UnaryOp(ExprWithRepr): + def __init__(self, inner): + self._inner = inner - return cls._constant(name="", value=other) + @property + def inner(self) -> Expr: + return self._inner + + @property + def left(self) -> ExprWithRepr: # type: ignore[override] + return self._inner - # FIXME: add shape rules in algebra - # @classmethod - # def _assert_same_shape(cls: T, other: T) -> None: - # if not cls.shape == other.shape: - # raise AlgebraicError("Cannot perform algebraic operations between " - # f"objects with different shapes {cls.shape} and {other.shape}") + @property + def right(self) -> ExprWithRepr: # type: ignore[override] + return Nothing() # type: ignore -class ReducibleExpr(HasRepr, Protocol): +class Reducible(HasRepr, Protocol): """ Reducible Expression Algebraic expression that can be written in terms of other (existing) @@ -91,80 +72,83 @@ class ReducibleExpr(HasRepr, Protocol): """ Reduce the expression to its basic elements """ -# FIXME: The type checker really hates this trick, what is a better solution? -def binary_operator(left_type: AlgebraicStructure, right_type: AlgebraicStructure): - """ Class decorator that adds constructor for binary operations of Expr. - - Classes that inherit :py:class:`mdpoly.abc.Expr` take values for *left* and - *right* to represent the operands. This binary operator specifies the - algebra (:py:class:`mdpoly.algebra.AlgebraicStructure`) that *left* and - *right* have to respect in order for the operation to be correct. - Concretely, this is to raise and exception for eg. when a scalar is added - to a matrix, etc. - """ - # TODO: add right_shape and left_shape for matrices - def decorator(cls: Expr) -> Expr: - init_cls = cls.__init__ - @wraps(cls.__init__) - def new_init_cls(self, left, right, *args, **kwargs): - init_cls(self, *args, **kwargs) - # Wrong algebra - if not isinstance(left, left_type) or not isinstance(right, right_type): - # None of the two is a Leaf. This is a workaround because - # adding the algebra to the base Leaf types Const, Var, etc. is - # not possible without having circular imports. For the exported types - # Constant, Variable, etc. this is not a problem. - if not isinstance(left, Leaf) and not isinstance(right, Leaf): - raise AlgebraicError( - "Cannot perform operation between types from " - f"different algebraic structures {type(left)} ({left._algebra}) " - f"and {type(right)} ({right._algebra})") - - self.left, self.right = left, right - cls.__init__ = new_init_cls - return cls - return decorator - - -# FIXME: same as binary_operator -def unary_operator(inner_type: AlgebraicStructure): - """ Class decorator that adss constructor for unary operations of Expr - - This is the special case of the binary operator, where only *left* is used - (because it is assumed that the operator acts from the left. The field for - *right* contains a placeholder :py:class:`mdpoly.leaves.Nothing`, - furthermore a new field *inner* is added as alias to *left*. - - See also :py:func:`mdpoly.algebra.binary_operator`. - """ - def decorator(cls: Expr) -> Expr: - """ Class decorator to modify constructor of unary operators. """ - init_cls = cls.__init__ - @wraps(cls.__init__) - def new_init_cls(self, left, *args, **kwargs): - init_cls(self, *args, **kwargs) - self.left, self.right = left, Nothing - - def inner(self): - return self.left +# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ +# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ +# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ - cls.__init__ = new_init_cls - cls.inner = property(inner) - return cls - return decorator +class PolyRingExpr(Expr): + r""" Endows with the algebraic structure of a polynomial ring. + This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the + polynomials are scalars. + """ -class Add(Expr, HasRepr): - """ Generic addition (no type check) """ + @property + def algebra(self): + return @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ if self.left.shape != self.right.shape: - raise InvalidShape(f"Cannot add {self.left} and {self.right} with " + raise InvalidShape(f"Cannot perform operation {repr(self)} with " f"shapes {self.left.shape} and {self.right.shape}.") return self.left.shape + def __add__(self, other): + other = self._wrap(Number, Const, other) + return PolyAdd(self, other) + + def __radd__(self, other): + other = self._wrap(Number, Const, other) + return PolyAdd(other, self) + + def __sub__(self, other): + other = self._wrap(Number, Const, other) + return PolySub(self, other) + + def __rsub__(self, other): + other = self._wrap(Number, Const, other) + return PolyAdd(other, self) + + def __neg__(self): + return PolyMul(self._constant(-1), self) + + def __mul__(self, other): + other = self._wrap(Number, Const, other) + return PolyMul(self, other) + + def __rmul__(self, other): + other = self._wrap(Number, Const, other) + return PolyMul(other, self) + + def __matmul__(self, other): + raise AlgebraicError("Cannot perform matrix multiplication in polynomial ring (they are scalars).") + + def __rmatmul__(self, other): + self.__rmatmul__(other) + + def __truediv__(self, other): + other = self._wrap(Number, Const, other) + if not self._is_const_or_param(other): + raise AlgebraicError("Cannot divide by variables in polynomial ring.") + + return PolyMul(self, other) + + def __rtruediv__(self, other): + raise AlgebraicError("Cannot perform right division in polynomial ring.") + + def __pow__(self, other): + other = self._wrap(Number, Const, other) + if not isinstance(other, Const | Param): + raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in " + "polynomial ring. Only constants and parameters are allowed.") + return PolyExp(self, other) + + +class PolyAdd(BinaryOp, PolyRingExpr): + """ Addition operator between scalar polynomials. """ + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ # Make a new empty representation @@ -186,34 +170,19 @@ class Add(Expr, HasRepr): return f"({self.left} + {self.right})" -class Sub(ReducibleExpr): - """ Generic subtraction operator (no type check) """ - - @property - def shape(self) -> Shape: - """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - if self.left.shape != self.right.shape: - raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") - return self.left.shape +class PolySub(BinaryOp, PolyRingExpr, Reducible): + """ Subtraction operator between scalar polynomials. """ def reduce(self) -> HasRepr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ - # return self.left + (-1 * self.right) - return Add(self.left, Mul(Const(value=-1), self.right)) + return self.left + (-1 * self.right) def __repr__(self) -> str: return f"({self.left} - {self.right})" -class Mul(Expr, HasRepr): - """ Generic multiplication operator (no type check). """ - - @property - def shape(self) -> Shape: - """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - if self.left.shape != self.right.shape: - raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") - return self.left.shape +class PolyMul(BinaryOp, PolyRingExpr): + """ Multiplication operator between scalar polynomials. """ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ @@ -241,19 +210,23 @@ class Mul(Expr, HasRepr): return f"({self.left} * {self.right})" -class Exp(ReducibleExpr): +class PolyExp(BinaryOp, PolyRingExpr, Reducible): """ Generic exponentiation (no type check). """ @property - def shape(self) -> Shape: - """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - if self.left.shape != self.right.shape: - raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") - return self.left.shape + def right(self) -> Const: # type: ignore[override] + if not isinstance(super().right, Const): + raise AlgebraicError(f"Cannot raise {self.left} to {self.right} because" + f"{self.right} is not a constant.") + + return cast(Const, super().right) def reduce(self) -> HasRepr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ var = self.left + if not isinstance(self.right.value, int): + raise NotImplementedError + ntimes = self.right.value - 1 return reduce(operator.mul, (var for _ in range(ntimes)), var) @@ -261,99 +234,10 @@ class Exp(ReducibleExpr): return f"({self.left} ** {self.right})" -# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ -# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ -# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ - -class PolyRingAlgebra(AlgebraicStructure, Protocol): - r""" Endows with the algebraic structure of a polynomial ring. - - This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the - polynomials are scalars. - """ - _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.poly_ring - _parameter: type = Param - _constant: type = Const - - def __add__(self, other): - other = self._wrap_if_constant(other) - return PolyAdd(self, other) - - def __radd__(self, other): - other = self._wrap_if_constant(other) - return PolyAdd(other, self) - - def __sub__(self, other): - other = self._wrap_if_constant(other) - return PolySub(self, other) - - def __rsub__(self, other): - other = self._wrap_if_constant(other) - return PolyAdd(other, self) - - def __neg__(self, other): - return PolyMul(self._constant(-1), other) - - def __mul__(self, other): - other = self._wrap_if_constant(other) - return PolyMul(self, other) - - def __rmul__(self, other): - other = self._wrap_if_constant(other) - return PolyMul(other, self) - - def __truediv__(self, other): - other = self._wrap_if_constant(other) - if not self._is_const_or_param(other): - raise AlgebraicError("Cannot divide by variables in polynomial ring.") - - return PolyMul(self, other) - - def __pow__(self, other): - other = self._wrap_if_constant(other) - if not self._is_const_or_param(other): - raise AlgebraicError(f"Cannot raise to powers of type {type(other)} in " - "polynomial ring. Only constants and parameters are allowed.") - return PolyExp(self, other) - - -@binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class PolyAdd(Add, PolyRingAlgebra): - """ Addition operator between scalar polynomials. """ - - -@binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class PolySub(Sub, PolyRingAlgebra): - """ Subtraction operator between scalar polynomials. """ - - -@binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class PolyMul(Mul, PolyRingAlgebra): - """ Multiplication operator between scalar polynomials. """ - - -@binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class PolyDiv(ReducibleExpr, PolyRingAlgebra): - """ Division of scalar polynomial by scalar. """ - - def reduce(self) -> HasRepr: - """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ - inverse = self._constant(value=1/self.right.value) - return PolyMul(inverse, self.left) - - def __repr__(self) -> str: - return f"({self.left} / {self.right})" - - -@binary_operator(PolyRingAlgebra, PolyRingAlgebra) -class PolyExp(Exp, PolyRingAlgebra): - """ Exponentiation operator of scalar polynomials. """ - - -@unary_operator(PolyRingAlgebra) -class PartialDiff(Expr, HasRepr, PolyRingAlgebra): +class PolyPartialDiff(UnaryOp, PolyRingExpr): """ Partial differentiation of scalar polynomials """ - def __init__(self, with_respect_to: Var): + def __init__(self, inner: Expr, with_respect_to: Var): + UnaryOp.__init__(self, inner) self.wrt = with_respect_to @property @@ -384,7 +268,7 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra): # ╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ -class RationalFieldAlgebra(AlgebraicStructure, Protocol): +class RationalFieldExpr(Expr): """ Endows with the algebraic structure of a field of rational functions """ @@ -393,7 +277,7 @@ class RationalFieldAlgebra(AlgebraicStructure, Protocol): # ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ # ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ -class MatrixAlgebra(AlgebraicStructure, Protocol): +class MatrixExpr(Expr): r""" Endows with the algebraic structure of a matrix ring and / or module depending on the shape. @@ -411,37 +295,36 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): already included (eg. transposition). """ - _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.matrix_ring - # FIXME: consider MatParam or something like that? - _parameter: type = Param - _constant: type = Const + @property + def algebra(self): + return Algebra.matrix_algebra def __add__(self, other): - other = self._wrap_if_constant(other) + other = self._wrap(Number, Const, other) return MatAdd(self, other) def __sub__(self, other): - other = self._wrap_if_constant(other) + other = self._wrap(Number, Const, other) return MatSub(self, other) def __rsub__(self, other): - other = self._wrap_if_constant(other) + other = self._wrap(Number, Const, other) return MatSub(other, self) def __mul__(self, other): - other = self._wrap_if_constant(other) + other = self._wrap(Number, Const, other) return MatScalarMul(other, self) def __rmul__(self, other): - scalar = self._wrap_if_constant(other) - return MatScalarMul(scalar, self) + other = self._wrap(Number, Const, other) + return MatScalarMul(other, self) def __matmul__(self, other): - other = self._wrap_if_constant(other) + other = self._wrap(Number, Const, other) return MatMul(self, other) def __rmatmul(self, other): - other = self._wrap_if_constant(other) + other = self._wrap(Number, Const, other) return MatMul(other, self) def __truediv__(self, scalar): @@ -457,31 +340,54 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): """ Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """ return self.transpose() - # def to_scalar(self, scalar_type: type): - # """ Convert to a scalar expression. """ - # raise NotImplementedError + def to_scalar(self, scalar_type: type): + """ Convert to a scalar expression. """ + raise NotImplementedError + + +class MatAdd(BinaryOp, PolyRingExpr): + """ Addition operator between matrices. """ + + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ + # Make a new empty representation + r = repr_type() + + # Make representations for existing stuff + for node in self.children(): + nrepr, state = node.to_repr(repr_type, state) + + # Process non-zero terms + for entry in nrepr.entries(): + for term in nrepr.terms(entry): + s = r.at(entry, term) + nrepr.at(entry, term) + r.set(entry, term, s) + + return r, state + + def __repr__(self) -> str: + return f"({self.left} + {self.right})" -@binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatAdd(Add, MatrixAlgebra): - """ Matrix Addition. """ +class MatSub(BinaryOp, PolyRingExpr, Reducible): + """ Subtraction operator between matrices. """ + def reduce(self) -> HasRepr: + """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ + return self.left + (-1 * self.right) -@binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatSub(Sub, MatrixAlgebra): - """ Matrix Subtraction. """ + def __repr__(self) -> str: + return f"({self.left} - {self.right})" -@binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatElemMul(Mul, MatrixAlgebra): +class MatElemMul(BinaryOp, MatrixExpr): """ Elementwise Matrix Multiplication. """ def __repr__(self) -> str: return f"({self.left} .* {self.right})" -@binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatScalarMul(ReducibleExpr, MatrixAlgebra): +class MatScalarMul(BinaryOp, MatrixExpr): """ Matrix-Scalar Multiplication. """ @property @@ -496,8 +402,7 @@ class MatScalarMul(ReducibleExpr, MatrixAlgebra): raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.") -@binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatDotProd(Expr, MatrixAlgebra): +class MatDotProd(BinaryOp, MatrixExpr): """ Dot product. """ @property @@ -514,8 +419,7 @@ class MatDotProd(Expr, MatrixAlgebra): return Shape.scalar() -@binary_operator(MatrixAlgebra, MatrixAlgebra) -class MatMul(Expr, MatrixAlgebra): +class MatMul(BinaryOp, MatrixExpr): """ Matrix Multiplication. """ @property @@ -532,11 +436,10 @@ class MatMul(Expr, MatrixAlgebra): return f"({self.left} @ {self.right})" -@unary_operator(MatrixAlgebra) -class MatTranspose(Expr, MatrixAlgebra): +class MatTranspose(UnaryOp, MatrixExpr): """ Matrix transposition """ @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - return Shape(self.inner.cols, self.inner.rows) + return Shape(self.inner.shape.cols, self.inner.shape.rows) diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index af43435..83324a3 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -1,4 +1,4 @@ -from .abc import Leaf, Repr +from .abc import Algebra, Leaf, Repr from .index import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex from .state import State from .errors import MissingParameters @@ -14,6 +14,7 @@ class Nothing(Leaf): """ name: str = "<nothing>" shape: Shape = Shape(0, 0) + algebra: Algebra = Algebra.none @dataclass(frozen=True) @@ -24,6 +25,7 @@ class Const(Leaf, HasRepr): value: Number name: str = "" shape: Shape = Shape.scalar() + algebra: Algebra = Algebra.none def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: r = repr_type() @@ -44,6 +46,7 @@ class Var(Leaf, HasRepr): """ name: str shape: Shape = Shape.scalar() + algebra: Algebra = Algebra.none def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: r = repr_type() @@ -62,6 +65,7 @@ class Param(Leaf, HasRepr): """ name: str shape: Shape = Shape.scalar() + algebra: Algebra = Algebra.none def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: if self not in state.parameters: |