aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-19 03:38:52 +0100
committerNao Pross <np@0hm.ch>2024-03-19 03:38:52 +0100
commitb6ac9461da70b0d3cf053caaca3ec03309b7da91 (patch)
tree338b0c92e2a8fd5289039f9110f10da0f85f49ee
parentAdd helper method WithOps.to_sparse to give sparse representation (diff)
downloadmdpoly-b6ac9461da70b0d3cf053caaca3ec03309b7da91.tar.gz
mdpoly-b6ac9461da70b0d3cf053caaca3ec03309b7da91.zip
Improve error messages at evaluation time
-rw-r--r--examples/error_chain.py26
-rw-r--r--mdpoly/expressions.py65
-rw-r--r--mdpoly/operations/__init__.py20
-rw-r--r--mdpoly/operations/exp.py7
4 files changed, 113 insertions, 5 deletions
diff --git a/examples/error_chain.py b/examples/error_chain.py
new file mode 100644
index 0000000..31a75ea
--- /dev/null
+++ b/examples/error_chain.py
@@ -0,0 +1,26 @@
+try:
+ import mdpoly
+except ModuleNotFoundError:
+ import sys
+ import pathlib
+ parent = pathlib.Path(__file__).resolve().parent.parent
+ sys.path.append(str(parent))
+
+from mdpoly import State, Variable, Parameter
+from mdpoly.representations import SparseRepr
+
+def make_invalid_expr(x, y):
+ return x ** (-4)
+
+def make_another_invalid(x, y):
+ return x ** -1
+
+def make_ok_expr(x, y):
+ w = make_invalid_expr(x, y)
+ z = make_another_invalid(x, y)
+ return x + z + y + w
+
+x, y, z = Variable.from_names("x, y, z")
+
+bad = make_ok_expr(x, y)
+bad.to_sparse(State())
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index 8836c52..4600660 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -11,6 +11,7 @@ from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
from .errors import MissingParameters, AlgebraicError
from .representations import SparseRepr
+from .operations import WithTraceback
from .operations.add import MatAdd, MatSub
from .operations.mul import MatElemMul
from .operations.exp import PolyExp
@@ -164,7 +165,7 @@ class WithOps:
# TODO: not very idiomatic, how much is this unidiomatic?
return self.unwrap()
- def __exit__(self, *ex):
+ def __exit__(self, exc_type: type, exc_val: Exception, tb):
""" Context manager exit.
See :py:meth:`mdpoly.expresssions.WithOps.__enter__`.
@@ -333,12 +334,72 @@ class WithOps:
# -- Make representations ---
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ """ Construct a representation. """
+ # Code below is to improve error messages when an exception occurrs
+ # during lazy evaluation.
+
+ def raises(expr: Expr):
+ """ Check if an expression raises an exception. """
+ try:
+ _ = expr.to_repr(repr_type, state)
+ return False
+ except Exception:
+ return True
+
+ def find_bad(expr):
+ """ binary search which expression is bad """
+ left, right = raises(expr.left), raises(expr.right)
+ if not left and not right:
+ # left and right are good, exception occurred exactly here
+ return expr
+
+ elif left and not right:
+ # exception is caused by left
+ return find_bad(expr.left)
+
+ elif right and not left:
+ # exception is caused by right
+ return find_bad(expr.right)
+
+ # Multiple invalid expressions!
+ # We pick one, the other will appear next time.
+ return find_bad(expr.left)
+
+ try:
+ # try to evaluate exception
+ rep, state = self.unwrap().to_repr(repr_type, state)
+ return rep, state
+
+ except Exception as e:
+ bad = find_bad(self)
+ # for frame in bad._stack:
+ # print(bad, frame.function, frame.filename, frame.lineno)
+
+ import traceback
+ stackstr = "".join(traceback.format_stack(bad._stack[0].frame))
+ raise RuntimeError(f"Expression {str(bad)} caused the previous exception. "
+ "Below is the call stack at the time when the bad expression "
+ "object was constructed: \n\n" + stackstr) from e
+
+ finally:
+ def clear_stack(expr: Expr):
+ """ Recursively clear stack objects to free references. """
+ if isinstance(expr, WithTraceback):
+ expr._stack.clear()
+
+ if not expr.is_leaf:
+ clear_stack(expr.left)
+ clear_stack(expr.right)
+
+ clear_stack(self.expr)
+
def to_sparse(self, state: State) -> tuple[SparseRepr, State]:
""" Make a sparse representation.
See :py:class:`mdpoly.representations.SparseRepr`
"""
- return self.unwrap().to_repr(SparseRepr, state)
+ return self.to_repr(SparseRepr, state)
def to_dense(self, state: State) -> tuple[Repr, State]:
raise NotImplementedError
diff --git a/mdpoly/operations/__init__.py b/mdpoly/operations/__init__.py
index c75608d..cd95f45 100644
--- a/mdpoly/operations/__init__.py
+++ b/mdpoly/operations/__init__.py
@@ -3,8 +3,9 @@ from typing import TYPE_CHECKING
from typing import Type
from abc import abstractmethod
-from dataclasses import field
+from dataclasses import dataclass, field
from dataclassabc import dataclassabc
+from inspect import stack
from ..abc import ReprT, Expr, Nothing
@@ -12,8 +13,20 @@ if TYPE_CHECKING:
from ..state import State
+@dataclass
+class WithTraceback:
+ """ Dataclass that contains traceback information.
+
+ See where it is used in :py:meth:`mdpoly.expressions.WithOps.to_repr`.
+ """
+ def __post_init__(self):
+ # NOTE: leaving this information around indefinitely can cause memory
+ # leaks, it must be cleared by calling `self._stack.clear()`.
+ self._stack = stack()
+
+
@dataclassabc(eq=False)
-class BinaryOp(Expr):
+class BinaryOp(Expr, WithTraceback):
""" Binary operator.
TODO: desc
@@ -23,7 +36,7 @@ class BinaryOp(Expr):
@dataclassabc(eq=False)
-class UnaryOp(Expr):
+class UnaryOp(Expr, WithTraceback):
""" Unary operator.
TODO: desc
@@ -72,3 +85,4 @@ class Reducible(Expr):
# @abstractmethod
# def simplify(self) -> Expr:
# """ Simplify the expression """
+
diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py
index 2b17a07..ee90928 100644
--- a/mdpoly/operations/exp.py
+++ b/mdpoly/operations/exp.py
@@ -46,6 +46,13 @@ class PolyExp(BinaryOp, Reducible):
if not isinstance(self.right.value, int):
raise NotImplementedError("Cannot raise to non-integer powers (yet).")
+ if self.right.value == 0:
+ # FIXME: return constant
+ raise NotImplementedError
+
+ elif self.right.value < 0:
+ raise AlgebraicError("Cannot raise to negative powers (yet).")
+
ntimes = self.right.value - 1
return reduce(PolyMul, (var for _ in range(ntimes)), var)