memorax.buffers#

Episode-aware replay buffers for off-policy algorithms.

Episode Buffer#

memorax.buffers.make_episode_buffer(max_length, min_length, sample_batch_size, sample_sequence_length, get_start_flags=<function get_start_flags_from_done>, add_sequences=False, add_batch_size=None, min_length_time_axis=None)[source]#
Return type:

TrajectoryBuffer

Parameters:
  • max_length (int)

  • min_length (int)

  • sample_batch_size (int)

  • sample_sequence_length (int)

  • get_start_flags (Callable[[Experience], Array])

  • add_sequences (bool)

  • add_batch_size (int | None)

  • min_length_time_axis (int | None)

memorax.buffers.get_full_start_flags(experience)[source]#
Return type:

Array

Parameters:

experience (Experience)

memorax.buffers.get_start_flags_from_done(experience)[source]#
Return type:

Array

Parameters:

experience (Experience)
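
To make the start-flag contract more concrete, here is a minimal NumPy sketch of a function with the same shape as get_start_flags: it takes an experience and returns a boolean array of shape [batch, time]. It assumes, purely for illustration, that the experience is a dict with a prev_done entry. The actual Experience structure and the real implementation of get_start_flags_from_done may differ, and the name start_flags_from_prev_done is hypothetical.

```python
import numpy as np


def start_flags_from_prev_done(experience):
    """Hypothetical stand-in: flag a timestep as an episode start
    when the previous step ended an episode (prev_done is true)."""
    # Assumption: experience is a dict holding a [batch, time] `prev_done` array.
    return np.asarray(experience["prev_done"]).astype(bool)


# Two environments, five timesteps each.
experience = {
    "prev_done": np.array([[1, 0, 0, 1, 0],
                           [1, 0, 1, 0, 0]]),
}
start_flags = start_flags_from_prev_done(experience)

# Time indices at which a sampled sequence could begin, per batch row.
start_positions = [np.flatnonzero(row).tolist() for row in start_flags]
```

A custom function of this shape could presumably be passed as get_start_flags when the experience marks episode boundaries differently.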

Prioritized Episode Buffer#

memorax.buffers.make_prioritised_episode_buffer(max_length, min_length, sample_batch_size, sample_sequence_length, get_start_flags=<function get_start_flags_from_done>, add_sequences=False, add_batch_size=None, priority_exponent=0.6, device='cpu')[source]#

Create a prioritised episode buffer.

This buffer combines episode-aware sampling with prioritized experience replay:

  • Only samples from valid episode start positions (identified by get_start_flags)

  • Weights sampling by TD-error priorities

  • Returns indices and probabilities for importance sampling weight computation

Parameters:
  • max_length (int) – Maximum total capacity of the buffer in timesteps.

  • min_length (int) – Minimum number of timesteps before sampling is allowed.

  • sample_batch_size (int) – Number of sequences to sample per batch.

  • sample_sequence_length (int) – Length of each sampled sequence.

  • get_start_flags (Callable[[Experience], Array]) – Function that takes experience and returns a boolean array of shape [batch, time] indicating episode start positions. Defaults to get_start_flags_from_done, which uses prev_done.

  • add_sequences (bool) – If True, expect sequences when adding. If False, expect single timesteps.

  • add_batch_size (int | None) – Batch size of experience added to buffer. If None, expects unbatched experience.

  • priority_exponent (float) – Priority exponent (alpha in the PER paper). Controls how strongly sampling is prioritized: 0 = uniform sampling, 1 = full prioritization.

  • device (str) – Device for optimized operations (“cpu”, “gpu”, or “tpu”).

Return type:

PrioritisedTrajectoryBuffer

Returns:

PrioritisedTrajectoryBuffer with episode-aware priority sampling.

Example

>>> buffer = make_prioritised_episode_buffer(
...     max_length=100_000,
...     min_length=1000,
...     sample_batch_size=32,
...     sample_sequence_length=16,
...     add_batch_size=8,
...     priority_exponent=0.6,
... )
>>> state = buffer.init(sample_transition)
>>> state = buffer.add(state, transitions)
>>> sample = buffer.sample(state, rng_key)
>>> weights = compute_importance_weights(sample.probabilities, buffer_size, beta=0.4)
>>> # After computing TD-errors:
>>> state = buffer.set_priorities(state, sample.indices, jnp.abs(td_errors) + 1e-6)
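
The effect of priority_exponent can be illustrated with the standard prioritized experience replay rule P(i) = p_i^alpha / sum_k p_k^alpha. The NumPy sketch below is an assumption about how priorities map to sampling probabilities, not the buffer's actual internals, and priorities_to_probabilities is a hypothetical helper.

```python
import numpy as np


def priorities_to_probabilities(priorities, priority_exponent):
    """PER-style sampling distribution: P(i) = p_i^alpha / sum_k p_k^alpha."""
    scaled = np.asarray(priorities, dtype=np.float64) ** priority_exponent
    return scaled / scaled.sum()


priorities = np.array([1.0, 2.0, 4.0, 8.0])

# alpha = 0 ignores priorities entirely (uniform sampling).
uniform = priorities_to_probabilities(priorities, 0.0)
# alpha = 1 samples in direct proportion to priority.
full = priorities_to_probabilities(priorities, 1.0)
# The default of 0.6 sits between the two extremes.
default = priorities_to_probabilities(priorities, 0.6)
```

Under this rule, intermediate exponents flatten the distribution relative to alpha = 1, so high-priority sequences are still favored but low-priority ones keep a meaningful chance of being replayed.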

memorax.buffers.compute_importance_weights(probabilities, buffer_size, beta=0.4)[source]#

Compute importance sampling weights for prioritized experience replay.

The importance sampling weights correct for the bias introduced by non-uniform sampling. Weights are normalized by the maximum weight for stability.

w_i = (N * P(i))^(-beta) / max_j (N * P(j))^(-beta)

Here N is buffer_size and P(i) is the sampling probability of item i.

Parameters:
  • probabilities (Array) – Sampling probabilities from the buffer sample.

  • buffer_size (int) – Current number of valid items in the buffer.

  • beta (float) – Importance sampling exponent. Should be annealed from an initial value (e.g., 0.4) to 1.0 over the course of training.

Return type:

Array

Returns:

Normalized importance sampling weights with the same shape as probabilities.
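
The formula and the annealing advice above can be sketched in NumPy as follows. This is an illustrative re-implementation under the stated formula, not the library's source. Both importance_weights_sketch and the linear schedule linear_beta are hypothetical names, and a linear schedule is only one possible choice.

```python
import numpy as np


def importance_weights_sketch(probabilities, buffer_size, beta=0.4):
    """w_i = (N * P(i))^(-beta), divided by the largest such weight."""
    probabilities = np.asarray(probabilities, dtype=np.float64)
    raw = (buffer_size * probabilities) ** (-beta)
    return raw / raw.max()


def linear_beta(step, total_steps, beta_start=0.4, beta_end=1.0):
    """Hypothetical schedule: anneal beta linearly from beta_start to beta_end."""
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return beta_start + fraction * (beta_end - beta_start)


probabilities = np.array([0.5, 0.25, 0.125, 0.125])

# Early in training: partial correction (beta = 0.4).
weights = importance_weights_sketch(
    probabilities, buffer_size=4, beta=linear_beta(0, 1000)
)
# End of training: full correction (beta = 1.0).
weights_full = importance_weights_sketch(
    probabilities, buffer_size=4, beta=linear_beta(1000, 1000)
)
```

Because of the max-normalization, the least frequently sampled items receive weight 1 and every other weight is scaled down, which keeps the weighted loss from growing when probabilities are small.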

class memorax.buffers.PrioritisedEpisodeBufferSample[source]#

Bases: PrioritisedTrajectoryBufferSample, Generic[Experience]

Sample from prioritised episode buffer with priority information.

experience#

The sampled experience trajectories.

indices#

Indices corresponding to the sampled sequences (for priority updates).

probabilities#

Sampling probabilities of the sampled sequences (for importance weights).

__init__(experience, indices, probabilities)#
Parameters:
  • experience (Experience)

  • indices (Array)

  • probabilities (Array)

Return type:

None

from_tuple()#
items() → a set-like object providing a view on D's items#
keys() → a set-like object providing a view on D's keys#
replace(**kwargs)#
to_tuple()#
values() → an object providing a view on D's values#