Source code for memorax.utils.decorators

import functools

import jax


def callback(f):
    """Turn ``f`` into a host-side debug callback usable inside traced JAX code.

    Every invocation of the returned function is routed through
    ``jax.debug.callback(f, *args, **kwargs)``. This lets ``f`` (typically a
    side-effecting function such as a logger) run on concrete values even
    when the caller is being traced by ``jax.jit``/``vmap``/etc.

    The wrapper's return value is whatever ``jax.debug.callback`` returns.
    Per the JAX documentation that is ``None``, so ``f``'s own return value
    is not propagated to the caller.

    Note:
        Keyword arguments are forwarded verbatim to ``jax.debug.callback``.
        Any keyword that ``jax.debug.callback`` itself accepts (for example
        ``ordered``) is therefore consumed by JAX and never reaches ``f``.

    Args:
        f: The callable to invoke through ``jax.debug.callback``.

    Returns:
        A wrapper with ``f``'s metadata (``__name__``, ``__doc__``,
        ``__wrapped__``, ...) copied over by ``functools.wraps``.
    """

    def _route_through_debug_callback(*call_args, **call_kwargs):
        # jax.debug.callback is looked up at call time, exactly as before.
        outcome = jax.debug.callback(f, *call_args, **call_kwargs)
        return outcome

    decorate = functools.wraps(f)
    return decorate(_route_through_debug_callback)