diff options
author | Nao Pross <np@0hm.ch> | 2024-03-04 01:01:15 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-04 01:01:15 +0100 |
commit | 87a72a65dd73f59c2ce5a1c54cc0b60c54bf8721 (patch) | |
tree | 0ed738deabc48f6004a32d4bc5459d45ca00f45a | |
parent | Improve docstrings (diff) | |
download | mdpoly-87a72a65dd73f59c2ce5a1c54cc0b60c54bf8721.tar.gz mdpoly-87a72a65dd73f59c2ce5a1c54cc0b60c54bf8721.zip |
Add shape checks for PolyRingAlgebra
Diffstat (limited to '')
-rw-r--r-- | mdpoly/abc.py | 7 | ||||
-rw-r--r-- | mdpoly/algebra.py | 43 |
2 files changed, 35 insertions, 15 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 4bc2afa..991cc04 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -6,6 +6,7 @@ from .util import iszero from typing import Self, Sequence, Protocol, runtime_checkable from abc import abstractmethod +from functools import cached_property @runtime_checkable @@ -30,7 +31,11 @@ class Expr(Protocol): """ Binary tree to represent a mathematical expression. """ left: Self | Leaf right: Self | Leaf - shape: Shape + + @cached_property + @abstractmethod + def shape(self) -> Shape: + """ Computes the shape of the expression. """ def __repr__(self): name = self.__class__.__qualname__ diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 480a904..46a82b6 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -2,7 +2,7 @@ from .abc import Leaf, Expr, Repr from .leaves import Nothing, Const, Param, Var -from .errors import AlgebraicError +from .errors import AlgebraicError, InvalidShape from .state import State from .types import Shape, MatrixIndex, PolyIndex from .representations import HasRepr @@ -36,10 +36,6 @@ class AlgebraicStructure(Protocol): _parameter: type _constant: type - @property - @abstractmethod - def shape(self) -> Shape: ... - @classmethod def _is_constant(cls: T, other: T) -> bool: return isinstance(other, cls._constant) @@ -162,6 +158,12 @@ def unary_operator(inner_type: AlgebraicStructure): class Add(Expr, HasRepr): """ Generic addition (no type check) """ + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if self.left.shape != self.right.shape: + raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") + return self.left.shape + def to_repr(self, repr_type: type, state: State) -> Repr: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ # Make a new empty representation @@ -187,6 +189,12 @@ class Add(Expr, HasRepr): class Sub(Expr, ReducibleExpr): """ Generic subtraction operator (no type check) """ + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if self.left.shape != self.right.shape: + raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") + return self.left.shape + def reduce(self) -> Expr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ return self.left + (-1 * self.right) @@ -199,6 +207,12 @@ class Sub(Expr, ReducibleExpr): class Mul(Expr, HasRepr): """ Generic multiplication operator (no type check). """ + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if self.left.shape != self.right.shape: + raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") + return self.left.shape + def to_repr(self, repr_type: type, state: State) -> Repr: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ r = repr_type() @@ -224,6 +238,12 @@ class Mul(Expr, HasRepr): class Exp(Expr, ReducibleExpr): """ Generic exponentiation (no type check). """ + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + if self.left.shape != self.right.shape: + raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") + return self.left.shape + def reduce(self) -> Expr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ var = self.left @@ -248,10 +268,6 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): _parameter = Param _constant = Const - @property - def shape(self): - return Shape.scalar - def __add__(self, other): other = self._wrap_if_constant(other) return PolyAdd(self, other) @@ -333,6 +349,10 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra): def __init__(self, with_respect_to: Var): self.wrt = with_respect_to + def shape(self) -> Shape: + """ See :py:meth:`mdpoly.abc.Expr.shape`. """ + return self.inner.shape + def to_repr(self, repr_type: type, state: State) -> Repr: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ r = repr_type() @@ -351,11 +371,6 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra): return f"(∂_{self.wrt} {self.inner})" -@unary_operator -class Diff(Expr, PolyRingAlgebra): - """ Total differentiation. """ - - # ┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓ # ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫ # ╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹ |