Source code for memorax.buffers.episode_buffer

from typing import TYPE_CHECKING, Callable

from memorax.utils.typing import Array, Key

if TYPE_CHECKING:
    pass
else:
    pass

import chex
import jax
import jax.numpy as jnp
from flashbax import utils
from flashbax.buffers.trajectory_buffer import (
    Experience,
    TrajectoryBuffer,
    TrajectoryBufferSample,
    TrajectoryBufferState,
    make_trajectory_buffer,
)
from flashbax.utils import add_dim_to_args


def get_full_start_flags(experience: Experience) -> Array:
    """Flag every stored position as a permissible sequence start.

    Args:
        experience: Buffer contents; only ``experience.first.done`` is read,
            and only for its shape and dtype.

    Returns:
        An array of ones with the same shape and dtype as
        ``experience.first.done``.
    """
    return jnp.ones_like(experience.first.done)
def get_start_flags_from_done(experience: Experience) -> Array:
    """Use the stored ``first.done`` field as the sequence-start flags.

    Args:
        experience: Buffer contents exposing ``first.done``.

    Returns:
        ``experience.first.done`` itself, unmodified.
    """
    return experience.first.done
def validate_sample_batch_size(sample_batch_size: int, max_length: int):
    """Check that a sampled batch is no larger than the buffer capacity.

    Raises:
        ValueError: If ``sample_batch_size`` exceeds ``max_length``.
    """
    if sample_batch_size > max_length:
        raise ValueError("sample_batch_size must be less than or equal to max_length")


def validate_min_length(min_length: int, add_batch_size: int, max_length: int):
    """Check that the minimum fill level fits on the buffer's time axis.

    Both lengths are converted to time-axis lengths by integer division with
    ``add_batch_size`` before being compared.

    Raises:
        ValueError: If ``min_length // add_batch_size`` exceeds
            ``max_length // add_batch_size``.
    """
    used_min_length = min_length // add_batch_size
    max_length_time_axis = max_length // add_batch_size
    if used_min_length > max_length_time_axis:
        raise ValueError("min_length used is too large for the buffer size.")


def validate_max_length_add_batch_size(max_length: int, add_batch_size: int):
    """Check that the time axis has room for at least two steps.

    Raises:
        ValueError: If ``max_length // add_batch_size`` is smaller than 2.
    """
    max_length_time_axis = max_length // add_batch_size
    if max_length_time_axis < 2:
        # The condition accepts exactly 2, so the message says "at least 2".
        raise ValueError(
            "max_length//add_batch_size must be at least 2. It is currently "
            f"{max_length}//{add_batch_size} = {max_length_time_axis}"
        )


def validate_sample_sequence_length(
    sample_sequence_length: int, max_length: int, add_batch_size: int
):
    """Check that one sampled sequence fits on the buffer's time axis.

    Raises:
        ValueError: If ``sample_sequence_length`` exceeds
            ``max_length // add_batch_size``.
    """
    if sample_sequence_length > max_length // add_batch_size:
        raise ValueError(
            "sample_sequence_length must be <= max_length // add_batch_size"
        )


def validate_episode_buffer_args(
    max_length: int,
    min_length: int,
    sample_batch_size: int,
    sample_sequence_length: int,
    add_batch_size: int,
):
    """Run every argument check required before building an episode buffer.

    Raises:
        ValueError: If ``add_batch_size`` is not positive, or if any of the
            individual validators rejects its arguments.
    """
    # Every check below divides by add_batch_size; reject bad values up front
    # so the caller gets a ValueError instead of a ZeroDivisionError or a
    # misleading message about min_length.
    if add_batch_size < 1:
        raise ValueError("add_batch_size must be a positive integer")
    validate_sample_batch_size(sample_batch_size, max_length)
    validate_min_length(min_length, add_batch_size, max_length)
    validate_max_length_add_batch_size(max_length, add_batch_size)
    validate_sample_sequence_length(sample_sequence_length, max_length, add_batch_size)


def _valid_start_mask(
    state: TrajectoryBufferState[Experience], sample_sequence_length: int
) -> Array:
    """Mark the time indices from which a full-length window can be read.

    A start is valid when the ``sample_sequence_length`` steps beginning there
    are all written and contiguous in insertion order.

    NOTE(review): this relies on ``state.current_index`` being the next write
    position (so it is also the oldest slot once the buffer is full), which is
    the flashbax trajectory-buffer convention -- confirm if the state type
    changes.

    Args:
        state: Buffer state; ``experience``, ``current_index`` and ``is_full``
            are read.
        sample_sequence_length: Number of consecutive time steps per sample.

    Returns:
        Boolean array of shape ``(max_length_time_axis,)``.
    """
    _, max_length_time_axis = utils.get_tree_shape_prefix(state.experience, n_axes=2)
    time_indices = jnp.arange(max_length_time_axis)

    def _not_full() -> Array:
        # Only slots [0, current_index) are written, so the last usable start
        # is current_index - sample_sequence_length. When that is negative no
        # index qualifies, because time_indices are all >= 0.
        last_valid = state.current_index - sample_sequence_length
        return time_indices <= last_valid

    def _full() -> Array:
        # Offset of each slot from the oldest one. A window must not run past
        # the newest slot, otherwise it would wrap over the write head and
        # stitch the newest data onto the oldest.
        offset_from_oldest = (
            time_indices - state.current_index
        ) % max_length_time_axis
        return offset_from_oldest <= max_length_time_axis - sample_sequence_length

    return jax.lax.cond(state.is_full, _full, _not_full)
def make_episode_buffer(
    max_length: int,
    min_length: int,
    sample_batch_size: int,
    sample_sequence_length: int,
    get_start_flags: Callable[[Experience], Array] = get_start_flags_from_done,
    add_sequences: bool = False,
    add_batch_size: int | None = None,
    min_length_time_axis: int | None = None,
) -> TrajectoryBuffer:
    """Build a trajectory buffer whose samples begin at flagged start positions.

    The underlying storage is a flashbax trajectory buffer; its ``add`` is
    wrapped according to ``add_batch_size`` / ``add_sequences`` and its
    ``sample`` is replaced by a sampler that draws window starts from the
    positions flagged by ``get_start_flags``.

    Args:
        max_length: Total capacity; the time axis holds
            ``max_length // add_batch_size`` steps.
        min_length: Minimum fill level; the time-axis minimum passed to the
            underlying buffer is
            ``max(min_length // add_batch_size, sample_sequence_length)``.
        sample_batch_size: Number of sequences returned per sample call.
        sample_sequence_length: Number of consecutive time steps per sequence.
        get_start_flags: Maps the stored experience to an array of shape
            ``(add_batch_size, max_length_time_axis)`` whose non-zero entries
            mark permissible sequence starts.
        add_sequences: When False, ``add`` is wrapped with ``add_dim_to_args``
            so an extra dimension is inserted into the added data.
        add_batch_size: Batch size of added data. ``None`` means a batch size
            of 1 with the batch dimension inserted by ``add``.
        min_length_time_axis: NOTE(review): accepted but not used anywhere in
            this function -- confirm whether it should override the computed
            minimum.

    Returns:
        The trajectory buffer with ``add`` and ``sample`` replaced.

    Raises:
        ValueError: If ``validate_episode_buffer_args`` rejects the arguments.
    """
    add_batches = add_batch_size is not None
    if not add_batches:
        add_batch_size = 1

    validate_episode_buffer_args(
        max_length=max_length,
        min_length=min_length,
        sample_batch_size=sample_batch_size,
        sample_sequence_length=sample_sequence_length,
        add_batch_size=add_batch_size,
    )

    buffer = make_trajectory_buffer(
        max_length_time_axis=max_length // add_batch_size,
        min_length_time_axis=max(min_length // add_batch_size, sample_sequence_length),
        add_batch_size=add_batch_size,
        sample_batch_size=sample_batch_size,
        sample_sequence_length=sample_sequence_length,
        period=1,
        max_size=None,
    )

    add_fn = buffer.add

    if not add_batches:
        add_fn = add_dim_to_args(
            add_fn, axis=0, starting_arg_index=1, ending_arg_index=2
        )

    if not add_sequences:
        # Axis 1 when the caller supplies the batch dimension, axis 0 when the
        # batch dimension is itself inserted by the wrapper above.
        axis = 1 - int(not add_batches)
        add_fn = add_dim_to_args(
            add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2
        )

    def sample_fn(
        state: TrajectoryBufferState[Experience], rng_key: Key
    ) -> TrajectoryBufferSample[Experience]:
        """Sample ``sample_batch_size`` windows starting at flagged positions."""
        num_rows, max_length_time_axis = utils.get_tree_shape_prefix(
            state.experience, n_axes=2
        )

        start_flags = get_start_flags(state.experience)
        chex.assert_shape(start_flags, (num_rows, max_length_time_axis))
        start_flags = start_flags.astype(jnp.float32)

        valid_mask = _valid_start_mask(state, sample_sequence_length).astype(
            jnp.float32
        )
        # A position is a candidate only if it is flagged AND a full window
        # can be read from it.
        start_mask = start_flags * valid_mask[None, :]

        per_row = jnp.sum(start_mask, axis=1)
        total = jnp.sum(per_row)

        def _sample_from_starts(key: Key) -> tuple[Array, Array]:
            # Pick rows proportionally to their candidate count, then pick a
            # candidate within each chosen row.
            p_rows = per_row / total
            key_rows, key_starts = jax.random.split(key)
            rows = jax.random.choice(
                key_rows,
                a=num_rows,
                shape=(sample_batch_size,),
                p=p_rows,
                replace=True,
            )

            row_start_probs = start_mask[rows]
            row_start_probs = row_start_probs / jnp.maximum(
                jnp.sum(row_start_probs, axis=1, keepdims=True), 1.0
            )

            keys = jax.random.split(key_starts, sample_batch_size)
            starts = jax.vmap(
                lambda k, p: jax.random.choice(
                    k, a=max_length_time_axis, shape=(), p=p, replace=True
                )
            )(keys, row_start_probs)
            return rows, starts

        def _fallback_beginning(key: Key) -> tuple[Array, Array]:
            # No candidate start exists: read from the oldest stored step.
            # That is index 0 until the buffer wraps; once full it is the
            # write head, so the window never mixes newest and oldest data.
            # NOTE(review): relies on current_index being the next write
            # position (flashbax convention) -- confirm.
            rows = jax.random.randint(key, (sample_batch_size,), 0, num_rows)
            oldest = jnp.where(state.is_full, state.current_index, 0)
            # Match rows' dtype so both lax.cond branches return the same types.
            starts = jnp.broadcast_to(oldest, (sample_batch_size,)).astype(rows.dtype)
            return rows, starts

        rows, starts = jax.lax.cond(
            total > 0, _sample_from_starts, _fallback_beginning, rng_key
        )

        # Window indices wrap modulo the time-axis length.
        time_idx = (
            starts[:, None] + jnp.arange(sample_sequence_length)
        ) % max_length_time_axis
        experience = jax.tree.map(
            lambda x: x[rows[:, None], time_idx], state.experience
        )
        return TrajectoryBufferSample(experience=experience)

    return buffer.replace(add=add_fn, sample=sample_fn)