memorax.algorithms.R2D2
- class memorax.algorithms.R2D2
Bases: object
R2D2(
    cfg: memorax.algorithms.r2d2.R2D2Config,
    env: gymnax.environments.environment.Environment,
    env_params: gymnax.environments.environment.EnvParams,
    q_network: flax.linen.module.Module,
    optimizer: optax._src.base.GradientTransformation,
    buffer: flashbax.buffers.trajectory_buffer.TrajectoryBuffer,
    epsilon_schedule: collections.abc.Callable[[typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]],
    beta_schedule: collections.abc.Callable[[typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]],
)
- cfg: R2D2Config
- env: Environment
- env_params: EnvParams
- q_network: Module
- optimizer: GradientTransformation
- buffer: TrajectoryBuffer
- epsilon_schedule: Callable[[Array | ndarray | numpy.bool | number | bool | int | float | complex], Array | ndarray | numpy.bool | number | bool | int | float | complex]
- beta_schedule: Callable[[Array | ndarray | numpy.bool | number | bool | int | float | complex], Array | ndarray | numpy.bool | number | bool | int | float | complex]
- __init__(cfg, env, env_params, q_network, optimizer, buffer, epsilon_schedule, beta_schedule)
- Parameters:
cfg (R2D2Config)
env (Environment)
env_params (EnvParams)
q_network (Module)
optimizer (GradientTransformation)
buffer (TrajectoryBuffer)
epsilon_schedule (Callable[[Array | ndarray | numpy.bool | number | bool | int | float | complex], Array | ndarray | numpy.bool | number | bool | int | float | complex])
beta_schedule (Callable[[Array | ndarray | numpy.bool | number | bool | int | float | complex], Array | ndarray | numpy.bool | number | bool | int | float | complex])
- Return type:
None
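Both epsilon_schedule and beta_schedule are typed as plain callables that take one array-like value and return one array-like value. The sketch below shows one way such a callable could be built. It assumes the argument is a step counter and that the schedule may be called inside jax.jit, so it uses jax.numpy rather than Python's min/max. The helper make_linear_schedule and all numeric values are hypothetical illustrations, not part of memorax, and this page does not document what beta_schedule controls. optax.linear_schedule(init_value, end_value, transition_steps) returns a callable of the same shape and could be used instead.

```python
import jax
import jax.numpy as jnp


def make_linear_schedule(start: float, end: float, transition_steps: int):
    """Hypothetical helper (not a memorax API).

    Returns a callable that maps a step count to a value moving linearly
    from `start` to `end` over `transition_steps` steps, then holds at `end`.
    """

    def schedule(count):
        # Fraction of the transition completed, clipped to [0, 1] so the
        # value stays at `end` once `transition_steps` is reached.
        frac = jnp.clip(count / transition_steps, 0.0, 1.0)
        return start + frac * (end - start)

    return schedule


# Assumed values for illustration only.
epsilon_schedule = make_linear_schedule(1.0, 0.05, 100_000)
beta_schedule = make_linear_schedule(0.4, 1.0, 100_000)

print(float(epsilon_schedule(0)))  # 1.0
# jnp.clip keeps the schedule traceable, so it also works under jax.jit.
print(float(jax.jit(epsilon_schedule)(50_000)))
```

Building the schedule from jnp operations rather than Python control flow keeps it usable on traced step counters. Whether memorax actually calls the schedules under jit is an assumption here.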