aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-04 01:01:15 +0100
committerNao Pross <np@0hm.ch>2024-03-04 01:01:15 +0100
commit87a72a65dd73f59c2ce5a1c54cc0b60c54bf8721 (patch)
tree0ed738deabc48f6004a32d4bc5459d45ca00f45a
parentImprove docstrings (diff)
downloadmdpoly-87a72a65dd73f59c2ce5a1c54cc0b60c54bf8721.tar.gz
mdpoly-87a72a65dd73f59c2ce5a1c54cc0b60c54bf8721.zip
Add shape checks for PolyRingAlgebra
-rw-r--r--mdpoly/abc.py7
-rw-r--r--mdpoly/algebra.py43
2 files changed, 35 insertions, 15 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 4bc2afa..991cc04 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -6,6 +6,7 @@ from .util import iszero
from typing import Self, Sequence, Protocol, runtime_checkable
from abc import abstractmethod
+from functools import cached_property
@runtime_checkable
@@ -30,7 +31,11 @@ class Expr(Protocol):
""" Binary tree to represent a mathematical expression. """
left: Self | Leaf
right: Self | Leaf
- shape: Shape
+
+ @cached_property
+ @abstractmethod
+ def shape(self) -> Shape:
+ """ Computes the shape of the expression. """
def __repr__(self):
name = self.__class__.__qualname__
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 480a904..46a82b6 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -2,7 +2,7 @@
from .abc import Leaf, Expr, Repr
from .leaves import Nothing, Const, Param, Var
-from .errors import AlgebraicError
+from .errors import AlgebraicError, InvalidShape
from .state import State
from .types import Shape, MatrixIndex, PolyIndex
from .representations import HasRepr
@@ -36,10 +36,6 @@ class AlgebraicStructure(Protocol):
_parameter: type
_constant: type
- @property
- @abstractmethod
- def shape(self) -> Shape: ...
-
@classmethod
def _is_constant(cls: T, other: T) -> bool:
return isinstance(other, cls._constant)
@@ -162,6 +158,12 @@ def unary_operator(inner_type: AlgebraicStructure):
class Add(Expr, HasRepr):
""" Generic addition (no type check) """
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if self.left.shape != self.right.shape:
+ raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
+ return self.left.shape
+
def to_repr(self, repr_type: type, state: State) -> Repr:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
# Make a new empty representation
@@ -187,6 +189,12 @@ class Add(Expr, HasRepr):
class Sub(Expr, ReducibleExpr):
""" Generic subtraction operator (no type check) """
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if self.left.shape != self.right.shape:
+ raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
+ return self.left.shape
+
def reduce(self) -> Expr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
return self.left + (-1 * self.right)
@@ -199,6 +207,12 @@ class Sub(Expr, ReducibleExpr):
class Mul(Expr, HasRepr):
""" Generic multiplication operator (no type check). """
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if self.left.shape != self.right.shape:
+ raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
+ return self.left.shape
+
def to_repr(self, repr_type: type, state: State) -> Repr:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
r = repr_type()
@@ -224,6 +238,12 @@ class Mul(Expr, HasRepr):
class Exp(Expr, ReducibleExpr):
""" Generic exponentiation (no type check). """
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if self.left.shape != self.right.shape:
+ raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
+ return self.left.shape
+
def reduce(self) -> Expr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
var = self.left
@@ -248,10 +268,6 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
_parameter = Param
_constant = Const
- @property
- def shape(self):
- return Shape.scalar
-
def __add__(self, other):
other = self._wrap_if_constant(other)
return PolyAdd(self, other)
@@ -333,6 +349,10 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
def __init__(self, with_respect_to: Var):
self.wrt = with_respect_to
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ return self.inner.shape
+
def to_repr(self, repr_type: type, state: State) -> Repr:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
r = repr_type()
@@ -351,11 +371,6 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
return f"(∂_{self.wrt} {self.inner})"
-@unary_operator
-class Diff(Expr, PolyRingAlgebra):
- """ Total differentiation. """
-
-
# ┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓╻ ┏━┓╻ ┏━╸┏━╸┏┓ ┏━┓┏━┓
# ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┣━┫┃ ┣━┫┃ ┃╺┓┣╸ ┣┻┓┣┳┛┣━┫
# ╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹╹ ╹┗━╸ ╹ ╹┗━╸┗━┛┗━╸┗━┛╹┗╸╹ ╹