Source code for memorax.buffers.prioritised_episode_buffer

"""Prioritised episode buffer combining episode-aware sampling with priority-based replay.

This buffer extends the episode buffer with prioritized experience replay (PER) as described
in https://arxiv.org/abs/1511.05952. It combines:
1. Episode-aware sampling: Only samples from valid episode start positions
2. Priority-weighted sampling: Samples proportionally to TD-error priorities
3. Importance sampling weights: For correcting the bias introduced by non-uniform sampling
"""

import functools
from typing import TYPE_CHECKING, Callable, Generic

if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from chex import dataclass

import chex
import jax
import jax.numpy as jnp
from memorax.utils.typing import Key
from flashbax import utils
from flashbax.buffers import sum_tree
from flashbax.buffers.prioritised_trajectory_buffer import (
    SET_BATCH_FN,
    PrioritisedTrajectoryBuffer,
    PrioritisedTrajectoryBufferSample,
    PrioritisedTrajectoryBufferState,
    Probabilities,
    prioritised_init,
    set_priorities,
    validate_device,
    validate_priority_exponent,
)
from flashbax.buffers.sum_tree import SumTreeState
from flashbax.buffers.trajectory_buffer import Experience, can_sample
from flashbax.utils import add_dim_to_args
from jax import Array

from .episode_buffer import get_start_flags_from_done, validate_episode_buffer_args

Indices = Array


[docs] @dataclass(frozen=True) class PrioritisedEpisodeBufferSample( PrioritisedTrajectoryBufferSample, Generic[Experience] ): """Sample from prioritised episode buffer with priority information. Attributes: experience: The sampled experience trajectories. indices: Indices corresponding to the sampled sequences (for priority updates). probabilities: Sampling probabilities of the sampled sequences (for importance weights). """ pass
[docs] def compute_importance_weights( probabilities: Probabilities, buffer_size: int, beta: float = 0.4, ) -> Array: """Compute importance sampling weights for prioritized experience replay. The importance sampling weights correct for the bias introduced by non-uniform sampling. Weights are normalized by the maximum weight for stability. w_i = (N * P(i))^(-beta) / max_j(w_j) Args: probabilities: Sampling probabilities from the buffer sample. buffer_size: Current number of valid items in the buffer. beta: Importance sampling exponent. Should be annealed from initial value (e.g., 0.4) to 1.0 over the course of training. Returns: Normalized importance sampling weights with the same shape as probabilities. """ safe_probs = jnp.maximum(probabilities, 1e-10) weights = (buffer_size * safe_probs) ** (-beta) weights = weights / jnp.maximum(weights.max(), 1e-10) return weights
def _valid_start_mask( state: PrioritisedTrajectoryBufferState[Experience], sample_sequence_length: int ) -> Array: """Get mask of valid start positions based on buffer fill state. Args: state: The buffer state. sample_sequence_length: Length of sequences to sample. Returns: Boolean mask of shape [max_length_time_axis] indicating valid start positions. """ _, max_length_time_axis = utils.get_tree_shape_prefix(state.experience, n_axes=2) time_indices = jnp.arange(max_length_time_axis) def _not_full() -> Array: last_valid = jnp.maximum(state.current_index - sample_sequence_length, -1) return (time_indices >= 0) & (time_indices <= last_valid) def _full() -> Array: return jnp.ones((max_length_time_axis,), dtype=bool) return jax.lax.cond(state.is_full, _full, _not_full) def _get_priorities_for_positions( sum_tree_state: SumTreeState, add_batch_size: int, max_length_time_axis: int, ) -> Array: """Get priorities for all positions in the buffer. Since we use period=1, each timestep position maps to one item in the sum tree. Args: sum_tree_state: The sum tree state containing priorities. add_batch_size: Number of parallel environments (rows in buffer). max_length_time_axis: Length of time axis in buffer. Returns: Array of shape [add_batch_size, max_length_time_axis] with priorities. """ num_items_per_row = max_length_time_axis total_items = add_batch_size * num_items_per_row item_indices = jnp.arange(total_items) priorities = sum_tree.get(sum_tree_state, item_indices) priorities = priorities.reshape(add_batch_size, max_length_time_axis) return priorities def prioritised_episode_sample( state: PrioritisedTrajectoryBufferState[Experience], rng_key: Key, sample_batch_size: int, sample_sequence_length: int, get_start_flags: Callable[[Experience], Array], ) -> PrioritisedEpisodeBufferSample[Experience]: """Sample episodes weighted by priority, respecting episode boundaries. This combines episode-aware sampling with priority-weighted sampling: 1. Identify valid episode start positions using start_flags 2. Mask priorities to zero for non-start positions 3. Sample proportionally to masked priorities 4. Return samples with indices and probabilities for importance sampling Args: state: The prioritised buffer state. rng_key: Random key for sampling. sample_batch_size: Number of sequences to sample. sample_sequence_length: Length of each sampled sequence. get_start_flags: Function to extract episode start flags from experience. Returns: PrioritisedEpisodeBufferSample with experience, indices, and probabilities. """ add_batch_size, max_length_time_axis = utils.get_tree_shape_prefix( state.experience, n_axes=2 ) start_flags = get_start_flags(state.experience) chex.assert_shape(start_flags, (add_batch_size, max_length_time_axis)) start_flags = start_flags.astype(jnp.float32) valid_mask = _valid_start_mask(state, sample_sequence_length).astype(jnp.float32) combined_mask = start_flags * valid_mask[None, :] priorities = _get_priorities_for_positions( state.sum_tree_state, add_batch_size, max_length_time_axis ) masked_priorities = priorities * combined_mask flat_priorities = masked_priorities.flatten() total_priority = jnp.sum(flat_priorities) def _sample_with_priorities(key: Key) -> tuple[Array, Array, Array, Array]: probs = flat_priorities / jnp.maximum(total_priority, 1e-10) flat_indices = jax.random.choice( key, a=add_batch_size * max_length_time_axis, shape=(sample_batch_size,), p=probs, replace=True, ) rows = flat_indices // max_length_time_axis starts = flat_indices % max_length_time_axis selected_probs = probs[flat_indices] return rows, starts, flat_indices, selected_probs def _fallback_uniform(key: Key) -> tuple[Array, Array, Array, Array]: rows = jax.random.randint(key, (sample_batch_size,), 0, add_batch_size) starts = jnp.zeros((sample_batch_size,), dtype=jnp.int32) flat_indices = rows * max_length_time_axis + starts uniform_prob = 1.0 / (add_batch_size * max_length_time_axis) selected_probs = jnp.full((sample_batch_size,), uniform_prob) return rows, starts, flat_indices, selected_probs rows, starts, flat_indices, selected_probs = jax.lax.cond( total_priority > 0, _sample_with_priorities, _fallback_uniform, rng_key ) time_idx = ( starts[:, None] + jnp.arange(sample_sequence_length) ) % max_length_time_axis experience = jax.tree.map(lambda x: x[rows[:, None], time_idx], state.experience) item_indices = flat_indices return PrioritisedEpisodeBufferSample( experience=experience, indices=item_indices, probabilities=selected_probs, ) def prioritised_episode_add( state: PrioritisedTrajectoryBufferState[Experience], batch: Experience, sample_sequence_length: int, device: str, ) -> PrioritisedTrajectoryBufferState[Experience]: """Add experience to the prioritised episode buffer. New items are assigned the maximum recorded priority. Items that become invalid (overwritten or broken by the circular buffer) have their priority set to zero. Args: state: Current buffer state. batch: Batch of experience to add with shape [add_batch_size, seq_len, ...]. sample_sequence_length: Length of sequences that will be sampled. device: Device type for optimized operations ("cpu", "gpu", or "tpu"). Returns: Updated buffer state with new experience and priorities. """ chex.assert_tree_shape_prefix(batch, utils.get_tree_shape_prefix(state.experience)) chex.assert_trees_all_equal_dtypes(batch, state.experience) add_sequence_length = utils.get_tree_shape_prefix(batch, n_axes=2)[1] add_batch_size, max_length_time_axis = utils.get_tree_shape_prefix( state.experience, n_axes=2 ) data_indices = ( jnp.arange(add_sequence_length) + state.current_index ) % max_length_time_axis new_experience = jax.tree.map( lambda exp_field, batch_field: exp_field.at[:, data_indices].set(batch_field), state.experience, batch, ) new_item_time_indices = ( jnp.arange(add_sequence_length) + state.current_index ) % max_length_time_axis row_offsets = jnp.arange(add_batch_size)[:, None] * max_length_time_axis newly_valid_items = (new_item_time_indices[None, :] + row_offsets).flatten() new_priorities = jnp.full( newly_valid_items.shape, state.sum_tree_state.max_recorded_priority ) invalid_time_start = ( state.current_index - (sample_sequence_length - 1) + max_length_time_axis ) % max_length_time_axis invalid_time_indices = ( invalid_time_start + jnp.arange(sample_sequence_length - 1) ) % max_length_time_axis newly_invalid_items = (invalid_time_indices[None, :] + row_offsets).flatten() invalid_priorities = jnp.zeros(newly_invalid_items.shape) new_sum_tree_state = SET_BATCH_FN[device]( state.sum_tree_state, newly_invalid_items, invalid_priorities, ) new_sum_tree_state = SET_BATCH_FN[device]( new_sum_tree_state, newly_valid_items, new_priorities, ) new_current_index = state.current_index + add_sequence_length new_running_index = state.running_index + add_sequence_length new_is_full = state.is_full | (new_current_index >= max_length_time_axis) new_current_index = new_current_index % max_length_time_axis return state.replace( experience=new_experience, current_index=new_current_index, is_full=new_is_full, running_index=new_running_index, sum_tree_state=new_sum_tree_state, )
[docs] def make_prioritised_episode_buffer( max_length: int, min_length: int, sample_batch_size: int, sample_sequence_length: int, get_start_flags: Callable[[Experience], Array] = get_start_flags_from_done, add_sequences: bool = False, add_batch_size: int | None = None, priority_exponent: float = 0.6, device: str = "cpu", ) -> PrioritisedTrajectoryBuffer: """Create a prioritised episode buffer. This buffer combines episode-aware sampling with prioritized experience replay: - Only samples from valid episode start positions (identified by get_start_flags) - Weights sampling by TD-error priorities - Returns indices and probabilities for importance sampling weight computation Args: max_length: Maximum total capacity of the buffer in timesteps. min_length: Minimum number of timesteps before sampling is allowed. sample_batch_size: Number of sequences to sample per batch. sample_sequence_length: Length of each sampled sequence. get_start_flags: Function that takes experience and returns boolean array of shape [batch, time] indicating episode start positions. Defaults to get_start_flags_from_done which uses prev_done. add_sequences: If True, expect sequences when adding. If False, expect single timesteps. add_batch_size: Batch size of experience added to buffer. If None, expects unbatched experience. priority_exponent: Priority exponent (alpha in PER paper). Controls how much prioritization is used. 0 = uniform sampling, 1 = full prioritization. device: Device for optimized operations ("cpu", "gpu", or "tpu"). Returns: PrioritisedTrajectoryBuffer with episode-aware priority sampling. Example: >>> buffer = make_prioritised_episode_buffer( ... max_length=100_000, ... min_length=1000, ... sample_batch_size=32, ... sample_sequence_length=16, ... add_batch_size=8, ... priority_exponent=0.6, ... ) >>> state = buffer.init(sample_transition) >>> state = buffer.add(state, transitions) >>> sample = buffer.sample(state, rng_key) >>> weights = compute_importance_weights(sample.probabilities, buffer_size, beta=0.4) >>> # After computing TD-errors: >>> state = buffer.set_priorities(state, sample.indices, jnp.abs(td_errors) + 1e-6) """ if add_batch_size is None: add_batch_size = 1 add_batches = False else: add_batches = True validate_episode_buffer_args( max_length=max_length, min_length=min_length, sample_batch_size=sample_batch_size, sample_sequence_length=sample_sequence_length, add_batch_size=add_batch_size, ) validate_priority_exponent(priority_exponent) if not validate_device(device): device = "cpu" max_length_time_axis = max_length // add_batch_size min_length_time_axis = max(min_length // add_batch_size, sample_sequence_length) period = 1 init_fn = functools.partial( prioritised_init, add_batch_size=add_batch_size, max_length_time_axis=max_length_time_axis, period=period, ) add_fn = functools.partial( prioritised_episode_add, sample_sequence_length=sample_sequence_length, device=device, ) if not add_batches: add_fn = add_dim_to_args( add_fn, axis=0, starting_arg_index=1, ending_arg_index=2 ) if not add_sequences: axis = 1 - int(not add_batches) add_fn = add_dim_to_args( add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2 ) sample_fn = functools.partial( prioritised_episode_sample, sample_batch_size=sample_batch_size, sample_sequence_length=sample_sequence_length, get_start_flags=get_start_flags, ) can_sample_fn = functools.partial( can_sample, min_length_time_axis=min_length_time_axis ) set_priorities_fn = functools.partial( set_priorities, priority_exponent=priority_exponent, device=device ) return PrioritisedTrajectoryBuffer( init=init_fn, add=add_fn, sample=sample_fn, can_sample=can_sample_fn, set_priorities=set_priorities_fn, )