Source code for memorax.utils.decorators
import functools
import jax
[docs]
def callback(f):
    """Return a version of ``f`` whose calls are dispatched via ``jax.debug.callback``.

    The returned callable keeps ``f``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) courtesy of ``functools.wraps``. Every positional and
    keyword argument is forwarded unchanged to ``jax.debug.callback`` together
    with ``f``, and the result of that call is returned.

    NOTE(review): keyword arguments are handed to ``jax.debug.callback`` itself,
    so any keyword it reserves (e.g. ``ordered``) is interpreted by JAX instead
    of reaching ``f`` -- confirm against the JAX version in use.

    Args:
        f: The Python callable to invoke through ``jax.debug.callback``.

    Returns:
        A wrapper with the same name/docstring as ``f``.
    """

    def _dispatch(*positional, **keyword):
        # Hand the call to JAX's debug-callback mechanism rather than calling
        # ``f`` directly, so it can be used from traced JAX code.
        return jax.debug.callback(f, *positional, **keyword)

    # Equivalent to decorating ``_dispatch`` with ``@functools.wraps(f)``.
    return functools.wraps(f)(_dispatch)
import functools
import jax
[docs]
def callback(f):
    """Return a version of ``f`` whose calls are dispatched via ``jax.debug.callback``.

    The returned callable keeps ``f``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) courtesy of ``functools.wraps``. Every positional and
    keyword argument is forwarded unchanged to ``jax.debug.callback`` together
    with ``f``, and the result of that call is returned.

    NOTE(review): keyword arguments are handed to ``jax.debug.callback`` itself,
    so any keyword it reserves (e.g. ``ordered``) is interpreted by JAX instead
    of reaching ``f`` -- confirm against the JAX version in use.

    Args:
        f: The Python callable to invoke through ``jax.debug.callback``.

    Returns:
        A wrapper with the same name/docstring as ``f``.
    """

    def _dispatch(*positional, **keyword):
        # Hand the call to JAX's debug-callback mechanism rather than calling
        # ``f`` directly, so it can be used from traced JAX code.
        return jax.debug.callback(f, *positional, **keyword)

    # Equivalent to decorating ``_dispatch`` with ``@functools.wraps(f)``.
    return functools.wraps(f)(_dispatch)