Source code for memorax.utils.decorators
import functools
from typing import Callable
import jax
[docs]
def callback(function: Callable, *, ordered: bool = False) -> Callable:
    """Wrap ``function`` so each call is routed through ``jax.debug.callback``.

    This lets host-side Python (logging, printing, metric collection, ...)
    run from inside traced / jitted JAX code. The wrapper forwards exactly the
    positional and keyword arguments it receives to ``function``. Because the
    call goes through ``jax.debug.callback``, ``function`` sees concrete values
    rather than tracers.

    The wrapper itself always returns ``None``. Any value ``function`` returns
    is discarded, so ``function`` is useful only for its side effects.

    Args:
        function: The host-side callable to wrap.
        ordered: Forwarded to ``jax.debug.callback``. When ``True``, the staged
            computation enforces ordering of this callback with respect to
            other ordered callbacks. Defaults to ``False`` (the previous,
            unordered behavior). To use it as a decorator, write
            ``@functools.partial(callback, ordered=True)``.

    Returns:
        A wrapper carrying ``function``'s metadata (via ``functools.wraps``).
    """

    @functools.wraps(function)
    def wrapper(*args, **kwargs) -> None:
        # Hand args/kwargs to jax.debug.callback as two positional pytrees
        # instead of splatting them. That way a user keyword argument can never
        # collide with jax.debug.callback's own keywords (e.g. ``ordered``).
        jax.debug.callback(
            lambda call_args, call_kwargs: function(*call_args, **call_kwargs),
            args,
            kwargs,
            ordered=ordered,
        )

    return wrapper