memorax.buffers#
Episode-aware replay buffers for off-policy algorithms.
Episode Buffer#
- memorax.buffers.make_episode_buffer(max_length, min_length, sample_batch_size, sample_sequence_length, get_start_flags=<function get_start_flags_from_done>, add_sequences=False, add_batch_size=None, min_length_time_axis=None)[source]#
Prioritized Episode Buffer#
- memorax.buffers.make_prioritised_episode_buffer(max_length, min_length, sample_batch_size, sample_sequence_length, get_start_flags=<function get_start_flags_from_done>, add_sequences=False, add_batch_size=None, priority_exponent=0.6, device='cpu')[source]#
Create a prioritised episode buffer.
This buffer combines episode-aware sampling with prioritized experience replay:
- Only samples from valid episode start positions (identified by get_start_flags)
- Weights sampling by TD-error priorities
- Returns indices and probabilities for importance sampling weight computation
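The sampling rule described above can be illustrated in a few lines. This is a minimal sketch in plain Python, not memorax's implementation: it assumes a flat list of per-position priorities and boolean start flags, and the helper name `episode_start_probabilities` is hypothetical.

```python
def episode_start_probabilities(priorities, start_flags, priority_exponent=0.6):
    """Sampling probabilities restricted to valid episode start positions.

    Hypothetical helper for illustration only; not part of memorax.
    """
    # Raise each priority to the exponent (alpha in the PER paper) and
    # zero out every position that is not a valid episode start.
    scaled = [
        (p ** priority_exponent) if is_start else 0.0
        for p, is_start in zip(priorities, start_flags)
    ]
    total = sum(scaled)
    if total == 0.0:
        raise ValueError("no valid episode start positions to sample from")
    # Normalize so the probabilities over valid starts sum to 1.
    return [s / total for s in scaled]


priorities = [1.0, 4.0, 2.0, 9.0]
start_flags = [True, False, False, True]

# Exponent 0.5: only positions 0 and 3 can be drawn, weighted 1 : 3.
probs = episode_start_probabilities(priorities, start_flags, priority_exponent=0.5)

# Exponent 0: priorities are ignored, so valid starts are sampled uniformly.
uniform = episode_start_probabilities(priorities, start_flags, priority_exponent=0.0)
```

In this sketch, positions that are not episode starts always receive probability zero, whatever their priority, which is the sense in which the sampling is episode-aware.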
- Parameters:
  - max_length (int) – Maximum total capacity of the buffer in timesteps.
  - min_length (int) – Minimum number of timesteps before sampling is allowed.
  - sample_batch_size (int) – Number of sequences to sample per batch.
  - sample_sequence_length (int) – Length of each sampled sequence.
  - get_start_flags (Callable[[TypeVar(Experience, bound=Union[Array,ndarray,bool,number,Iterable[ArrayTree],Mapping[Any, ArrayTree]])],Array]) – Function that takes experience and returns boolean array of shape [batch, time] indicating episode start positions. Defaults to get_start_flags_from_done which uses prev_done.
  - add_sequences (bool) – If True, expect sequences when adding. If False, expect single timesteps.
  - add_batch_size (int|None) – Batch size of experience added to buffer. If None, expects unbatched experience.
  - priority_exponent (float) – Priority exponent (alpha in PER paper). Controls how much prioritization is used. 0 = uniform sampling, 1 = full prioritization.
  - device (str) – Device for optimized operations (“cpu”, “gpu”, or “tpu”).
- Return type:
  PrioritisedTrajectoryBuffer
- Returns:
  PrioritisedTrajectoryBuffer with episode-aware priority sampling.
Example
>>> import jax.numpy as jnp
>>> from memorax.buffers import (
...     compute_importance_weights,
...     make_prioritised_episode_buffer,
... )
>>> buffer = make_prioritised_episode_buffer(
...     max_length=100_000,
...     min_length=1000,
...     sample_batch_size=32,
...     sample_sequence_length=16,
...     add_batch_size=8,
...     priority_exponent=0.6,
... )
>>> state = buffer.init(sample_transition)
>>> state = buffer.add(state, transitions)
>>> sample = buffer.sample(state, rng_key)
>>> weights = compute_importance_weights(sample.probabilities, buffer_size, beta=0.4)
>>> # After computing TD-errors:
>>> state = buffer.set_priorities(state, sample.indices, jnp.abs(td_errors) + 1e-6)
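The example leaves open what happens between sampling and the priority update. The sketch below shows one common arrangement in plain Python. It assumes a squared-error loss, which the source does not specify, and the helper names `weighted_td_loss` and `updated_priorities` are hypothetical rather than part of memorax.

```python
def weighted_td_loss(td_errors, weights):
    """Mean of squared TD-errors, each scaled by its importance weight.

    Hypothetical helper; the squared-error form is an assumption.
    """
    weighted = [w * e * e for w, e in zip(weights, td_errors)]
    return sum(weighted) / len(weighted)


def updated_priorities(td_errors, epsilon=1e-6):
    """New priorities |td_error| + epsilon, mirroring the example's update."""
    # The small epsilon keeps every priority strictly positive, so no
    # sequence ends up with zero probability of being sampled again.
    return [abs(e) + epsilon for e in td_errors]


td_errors = [0.5, -2.0]
weights = [1.0, 0.5]

loss = weighted_td_loss(td_errors, weights)   # (1.0*0.25 + 0.5*4.0) / 2
new_priorities = updated_priorities(td_errors)
```

The importance weights reduce the contribution of sequences that were sampled more often than uniform sampling would allow, while the priority update makes large-error sequences more likely to be drawn next time.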
- memorax.buffers.compute_importance_weights(probabilities, buffer_size, beta=0.4)[source]#
Compute importance sampling weights for prioritized experience replay.
The importance sampling weights correct for the bias introduced by non-uniform sampling. Weights are normalized by the maximum weight for stability.
w_i = (N * P(i))^(-beta) / max_j(w_j)
- Parameters:
- Return type:
- Returns:
Normalized importance sampling weights with the same shape as probabilities.
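The formula above can be written out directly. This is a minimal plain-Python sketch of the stated formula, not the library's implementation (which presumably operates on arrays). The function name `importance_weights` is hypothetical, and every probability is assumed to be strictly positive.

```python
def importance_weights(probabilities, buffer_size, beta=0.4):
    """Plain-Python rendering of w_i = (N * P(i))^(-beta) / max_j(w_j).

    Hypothetical helper for illustration; assumes all probabilities > 0.
    """
    # Unnormalized weights: rarely sampled items (small P) get larger weights.
    raw = [(buffer_size * p) ** (-beta) for p in probabilities]
    # Dividing by the largest weight keeps every result in (0, 1].
    largest = max(raw)
    return [w / largest for w in raw]


# With N = 4 and beta = 1, an item sampled at the uniform rate (P = 1/N)
# keeps weight 1, while an item sampled three times as often gets 1/3.
weights = importance_weights([0.25, 0.75], buffer_size=4, beta=1.0)

# With beta = 0 no correction is applied: every weight equals 1.
uncorrected = importance_weights([0.1, 0.9], buffer_size=10, beta=0.0)
```

Because the weights are normalized by their maximum, they only ever scale updates down, which is the stability property mentioned above.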
- class memorax.buffers.PrioritisedEpisodeBufferSample[source]#
Bases: PrioritisedTrajectoryBufferSample, Generic[Experience]
Sample from prioritised episode buffer with priority information.
- experience#
The sampled experience trajectories.
- indices#
Indices corresponding to the sampled sequences (for priority updates).
- probabilities#
Sampling probabilities of the sampled sequences (for importance weights).
- __init__(experience, indices, probabilities)#
- from_tuple()#
- items() → a set-like object providing a view on D's items#
- keys() → a set-like object providing a view on D's keys#
- replace(**kwargs)#
- to_tuple()#
- values() → an object providing a view on D's values#