diff options
author | Nao Pross <np@0hm.ch> | 2024-03-19 03:38:52 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-19 03:38:52 +0100 |
commit | b6ac9461da70b0d3cf053caaca3ec03309b7da91 (patch) | |
tree | 338b0c92e2a8fd5289039f9110f10da0f85f49ee | |
parent | Add helper method WithOps.to_sparse to give sparse representation (diff) | |
download | mdpoly-b6ac9461da70b0d3cf053caaca3ec03309b7da91.tar.gz mdpoly-b6ac9461da70b0d3cf053caaca3ec03309b7da91.zip |
Improve error messages at evaluation time
Diffstat (limited to '')
-rw-r--r-- | examples/error_chain.py | 26 | ||||
-rw-r--r-- | mdpoly/expressions.py | 65 | ||||
-rw-r--r-- | mdpoly/operations/__init__.py | 20 | ||||
-rw-r--r-- | mdpoly/operations/exp.py | 7 |
4 files changed, 113 insertions, 5 deletions
diff --git a/examples/error_chain.py b/examples/error_chain.py new file mode 100644 index 0000000..31a75ea --- /dev/null +++ b/examples/error_chain.py @@ -0,0 +1,26 @@ +try: + import mdpoly +except ModuleNotFoundError: + import sys + import pathlib + parent = pathlib.Path(__file__).resolve().parent.parent + sys.path.append(str(parent)) + +from mdpoly import State, Variable, Parameter +from mdpoly.representations import SparseRepr + +def make_invalid_expr(x, y): + return x ** (-4) + +def make_another_invalid(x, y): + return x ** -1 + +def make_ok_expr(x, y): + w = make_invalid_expr(x, y) + z = make_another_invalid(x, y) + return x + z + y + w + +x, y, z = Variable.from_names("x, y, z") + +bad = make_ok_expr(x, y) +bad.to_sparse(State()) diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index 8836c52..4600660 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -11,6 +11,7 @@ from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number from .errors import MissingParameters, AlgebraicError from .representations import SparseRepr +from .operations import WithTraceback from .operations.add import MatAdd, MatSub from .operations.mul import MatElemMul from .operations.exp import PolyExp @@ -164,7 +165,7 @@ class WithOps: # TODO: not very idiomatic, how much is this unidiomatic? return self.unwrap() - def __exit__(self, *ex): + def __exit__(self, exc_type: type, exc_val: Exception, tb): """ Context manager exit. See :py:meth:`mdpoly.expresssions.WithOps.__enter__`. @@ -333,12 +334,72 @@ class WithOps: # -- Make representations --- + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + """ Construct a representation. """ + # Code below is to improve error messages when an exception occurrs + # during lazy evaluation. + + def raises(expr: Expr): + """ Check if an expression raises an exception. """ + try: + _ = expr.to_repr(repr_type, state) + return False + except Exception: + return True + + def find_bad(expr): + """ binary search which expression is bad """ + left, right = raises(expr.left), raises(expr.right) + if not left and not right: + # left and right are good, exception occurred exactly here + return expr + + elif left and not right: + # exception is caused by left + return find_bad(expr.left) + + elif right and not left: + # exception is caused by right + return find_bad(expr.right) + + # Multiple invalid expressions! + # We pick one, the other will appear next time. + return find_bad(expr.left) + + try: + # try to evaluate exception + rep, state = self.unwrap().to_repr(repr_type, state) + return rep, state + + except Exception as e: + bad = find_bad(self) + # for frame in bad._stack: + # print(bad, frame.function, frame.filename, frame.lineno) + + import traceback + stackstr = "".join(traceback.format_stack(bad._stack[0].frame)) + raise RuntimeError(f"Expression {str(bad)} caused the previous exception. " + "Below is the call stack at the time when the bad expression " + "object was constructed: \n\n" + stackstr) from e + + finally: + def clear_stack(expr: Expr): + """ Recursively clear stack objects to free references. """ + if isinstance(expr, WithTraceback): + expr._stack.clear() + + if not expr.is_leaf: + clear_stack(expr.left) + clear_stack(expr.right) + + clear_stack(self.expr) + def to_sparse(self, state: State) -> tuple[SparseRepr, State]: """ Make a sparse representation. See :py:class:`mdpoly.representations.SparseRepr` """ - return self.unwrap().to_repr(SparseRepr, state) + return self.to_repr(SparseRepr, state) def to_dense(self, state: State) -> tuple[Repr, State]: raise NotImplementedError diff --git a/mdpoly/operations/__init__.py b/mdpoly/operations/__init__.py index c75608d..cd95f45 100644 --- a/mdpoly/operations/__init__.py +++ b/mdpoly/operations/__init__.py @@ -3,8 +3,9 @@ from typing import TYPE_CHECKING from typing import Type from abc import abstractmethod -from dataclasses import field +from dataclasses import dataclass, field from dataclassabc import dataclassabc +from inspect import stack from ..abc import ReprT, Expr, Nothing @@ -12,8 +13,20 @@ if TYPE_CHECKING: from ..state import State +@dataclass +class WithTraceback: + """ Dataclass that contains traceback information. + + See where it is used in :py:meth:`mdpoly.expressions.WithOps.to_repr`. + """ + def __post_init__(self): + # NOTE: leaving this information around indefinitely can cause memory + # leaks, it must be cleared by calling `self._stack.clear()`. + self._stack = stack() + + @dataclassabc(eq=False) -class BinaryOp(Expr): +class BinaryOp(Expr, WithTraceback): """ Binary operator. TODO: desc @@ -23,7 +36,7 @@ class BinaryOp(Expr): @dataclassabc(eq=False) -class UnaryOp(Expr): +class UnaryOp(Expr, WithTraceback): """ Unary operator. TODO: desc @@ -72,3 +85,4 @@ class Reducible(Expr): # @abstractmethod # def simplify(self) -> Expr: # """ Simplify the expression """ + diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py index 2b17a07..ee90928 100644 --- a/mdpoly/operations/exp.py +++ b/mdpoly/operations/exp.py @@ -46,6 +46,13 @@ class PolyExp(BinaryOp, Reducible): if not isinstance(self.right.value, int): raise NotImplementedError("Cannot raise to non-integer powers (yet).") + if self.right.value == 0: + # FIXME: return constant + raise NotImplementedError + + elif self.right.value < 0: + raise AlgebraicError("Cannot raise to negative powers (yet).") + ntimes = self.right.value - 1 return reduce(PolyMul, (var for _ in range(ntimes)), var) |