memorax.algorithms.DQN#

class memorax.algorithms.DQN[source]#

Bases: object

DQN(cfg: memorax.algorithms.dqn.DQNConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, q_network: flax.linen.module.Module, optimizer: optax._src.base.GradientTransformation, buffer: flashbax.buffers.trajectory_buffer.TrajectoryBuffer, epsilon_schedule: collections.abc.Callable[[typing.Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]], typing.Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]])

cfg: DQNConfig#
env: Environment#
env_params: EnvParams#
q_network: Module#
optimizer: GradientTransformation#
buffer: TrajectoryBuffer#
epsilon_schedule: Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]#
init(key)[source]#
warmup(key, state, num_steps)[source]#
Return type:

tuple[Array, DQNState]

Parameters:

key
state
num_steps
train(key, state, num_steps)[source]#
Parameters:

key
state
num_steps
evaluate(key, state, num_steps)[source]#
Return type:

tuple[Array, dict]

Parameters:

key
state
num_steps
__init__(cfg, env, env_params, q_network, optimizer, buffer, epsilon_schedule)#
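Parameters:

cfg (DQNConfig)
env (Environment)
env_params (EnvParams)
q_network (Module)
optimizer (GradientTransformation)
buffer (TrajectoryBuffer)
epsilon_schedule (Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int])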
Return type:

None

replace(**updates)#

Returns a new object replacing the specified fields with new values.