Description
It would be great if JAX provided a simple way to raise errors inside jitted functions and produce meaningful error messages.
I understand the reasoning behind checkify and its functional, side-effect-free approach. However, it becomes very difficult to use in practice whenever subsequent code must actually be prevented from running when a condition is not met:
- Consider a function with a while loop that never terminates when an assumption is not satisfied.
- Consider calls to the FFI where out-of-bounds memory may be accessed if a condition is not met.
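For comparison, a minimal sketch of the checkify route (using `jax.experimental.checkify`; the `myfunc` body mirrors the usage example in this issue) illustrates the concern above: a failed check is only recorded in the returned error value, the rest of the function still runs, and a Python exception appears only once `err.throw()` is called on the host after the computation has finished.

```python
import jax
from jax.experimental import checkify


def myfunc(i, n):
    # The check is functionalized: a failure is recorded, not raised here.
    checkify.check(i < n, "i={i} cannot be bigger than n={n}", i=i, n=n)
    # This line still executes even when the check above fails.
    return n - i


checked = jax.jit(checkify.checkify(myfunc))

err, out = checked(10, 5)
print(out)        # -5: the computation after the failed check still ran
print(err.get())  # the formatted error message, or None if no check failed

try:
    err.throw()   # only at this point is a Python exception raised
except Exception as e:
    print("raised:", e)
```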
I understand that these cases could in principle be handled by guarding the dangerous code behind conditionals, but this becomes very impractical when the code that produces the invalid data is far removed from the code that might trigger undefined behavior. Sticking strictly to the functional approach would require accounting for every possible invalid code path, which can be a nightmare in complex projects.
I have found that an io_callback placed behind a jax.lax.cond reliably raises an error before any other code runs, at negligible extra cost. Execution of the error can be guaranteed by making subsequent code depend on the callback's return value. However, producing a meaningful error message this way is tricky: the run-time call stack is largely irrelevant, and the trace-time call stack is not included by default. Below is my current best effort at a good error message and a simple interface. It would be really useful if JAX provided a raise-if function like this by default, ideally with the uninteresting run-time stack removed and the error message simplified.
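The mechanism can be sketched as follows. This is a hypothetical reconstruction, not the `error_helpers` module referenced in this issue: it assumes `jax.experimental.io_callback` and `jax.lax.cond`, returns an `int32` zero so callers can add it to a value that later code uses, and records the caller's file and line with `inspect` at trace time. The exact exception type that reaches the user (the `ValueError` wrapped in a JAX runtime error) depends on the JAX version.

```python
import inspect

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def raise_if(cond, msg, **fmt_args):
    """Hypothetical helper: raise ValueError(msg) at run time if `cond` is True.

    Returns an int32 zero; adding it to a value used later forces the check
    to execute before that value is consumed.
    """
    # Capture the caller's location at *trace* time; the run-time Python
    # stack would only show JAX internals.
    frame = inspect.currentframe().f_back
    location = f"{frame.f_code.co_filename}:{frame.f_lineno}"

    def _raise(*vals):
        # Runs on the host only when the error branch is taken; vals are
        # concrete NumPy arrays, so they can be formatted into the message.
        values = {k: np.asarray(v) for k, v in zip(fmt_args, vals)}
        raise ValueError(f"{location}: " + msg.format(**values))

    def _error_branch(vals):
        # The declared result shape lets this branch return a value like _ok_branch.
        return io_callback(_raise, jax.ShapeDtypeStruct((), jnp.int32), *vals)

    def _ok_branch(vals):
        return jnp.zeros((), jnp.int32)

    return jax.lax.cond(cond, _error_branch, _ok_branch, tuple(fmt_args.values()))
```

Using `lax.cond` rather than an unconditional callback keeps the host round trip off the common path, which is presumably why the overhead is negligible when the condition holds.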
Usage:
    import jax
    from error_helpers import raise_if

    def myfunc(i, n):
        # Adding the return value makes later code depend on the check
        i = i + raise_if(i >= n, "i={i} cannot be bigger than n={n} for this to work", i=i, n=n)
        return n - i

    a = jax.jit(myfunc)(5, 10)  # runs fine
    b = jax.jit(myfunc)(10, 5)  # raises an error before "i" is used in further computation

Error message
[...]
======== Relevant Error Message =========
ValueError: /home/jens/repos/jz-tree/notebooks/tmp/error_example.py:6: i=10 cannot be bigger than n=5 for this to work
Trace (tracing time, most recent call last):
[...]
/home/jens/repos/jz-tree/notebooks/tmp/error_example.py:6 myfunc
=========================================
[...]