Recurrent replay buffer for Stable Baselines3 with R2D2-style chunked storage [1].
from typing import NamedTuple, Optional

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import BaseBuffer
from stable_baselines3.common.vec_env import VecNormalize


class RecurrentReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    rewards: th.Tensor
    dones: th.Tensor
    mask: th.Tensor
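
# Sampled batch shapes (B = batch size; see test_sample below):
#   observations:          (B, chunk_len + 1, obs_dim)
#   actions:               (B, chunk_len, act_dim)
#   rewards, dones, mask:  (B, chunk_len, 1)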


class RecurrentReplayBuffer(BaseBuffer):
    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        chunk_len: int = 120,
        overlap: int = 40,
        n_envs: int = 1,
        **kwargs
    ):
"""
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param chunk_len: total number of timesteps to store in each chunk:
l + m, for example, l = 40 is the burn-in length and m = 80 is the
"useful" length of the chunk [1]
:param overlap: overlap length between stored chunks [1]
[1] Kapturowski, Steven, et al. "Recurrent experience replay in distributed
reinforcement learning." International Conference on Learning
Representations. 2019.
"""
        # This might be something to rethink in the future. See
        # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/recurrent/buffers.py
        # for reference.
        assert n_envs == 1, "RecurrentReplayBuffer does not support multiple envs"
        super().__init__(
            buffer_size, observation_space, action_space, n_envs=n_envs, **kwargs
        )
        self.obs_dim = observation_space.shape[0]
        self.act_dim = action_space.shape[0]
        self.chunk_len = chunk_len
        self.overlap = overlap
        self.reset()

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        # Store chunks of episodes.
        # chunk_len + 1 because we also store the final next observation
        # of the chunk.
        self.o = np.zeros(
            (self.buffer_size, self.chunk_len + 1, self.obs_dim), dtype=np.float32
        )
        self.a = np.zeros(
            (self.buffer_size, self.chunk_len, self.act_dim), dtype=np.float32
        )
        self.r = np.zeros((self.buffer_size, self.chunk_len, 1), dtype=np.float32)
        self.d = np.zeros((self.buffer_size, self.chunk_len, 1), dtype=np.float32)
        # Mask: valid step = 1, no record = 0
        self.m = np.zeros((self.buffer_size, self.chunk_len, 1), dtype=np.float32)
        # self.pos (from the parent class) is the position of the episode
        # chunk in the buffer (a "row counter").
        # self.time_pos is the position of the timestep within the chunk
        # (a "column counter").
        self.time_pos = 0
        super().reset()

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
    ) -> None:
        # Copy to avoid modification by reference
        self.o[self.pos, self.time_pos] = np.array(obs).copy()
        self.o[self.pos, self.time_pos + 1] = np.array(next_obs).copy()
        self.a[self.pos, self.time_pos] = np.array(action).copy()
        self.r[self.pos, self.time_pos] = np.array(reward).copy()
        self.d[self.pos, self.time_pos] = np.array(done).copy()
        self.m[self.pos, self.time_pos] = 1
        # Update the time position in the chunk
        self.time_pos += 1
        # Chunk just finished
        end_of_chunk = self.time_pos == self.chunk_len
        # If the chunk is complete or the episode is done:
        # - Move to a new chunk position in the buffer (row counter)
        # - Reset the time position in the chunk (column counter)
        if end_of_chunk or done:  # n_envs == 1
            # Check whether the buffer is going to be full
            if self.pos == self.buffer_size - 1:
                self.full = True
            # Start a new chunk by updating the position in the buffer
            self.pos = (self.pos + 1) % self.buffer_size
            # Clear stale data (and the mask) left in this row from a
            # previous pass once the buffer wraps around, so that old
            # timesteps are not sampled as valid.
            self.o[self.pos] = 0
            self.a[self.pos] = 0
            self.r[self.pos] = 0
            self.d[self.pos] = 0
            self.m[self.pos] = 0
            # Overlap handling at the end of a chunk: copy the last
            # `overlap` timesteps to the beginning of the next chunk.
            # If the episode is done at the end of the chunk, there is
            # nothing to copy.
            if end_of_chunk and not done:
                self.o[self.pos, : self.overlap + 1] = self.o[
                    self.pos - 1, -(self.overlap + 1) :
                ]
                self.a[self.pos, : self.overlap] = self.a[self.pos - 1, -self.overlap :]
                self.r[self.pos, : self.overlap] = self.r[self.pos - 1, -self.overlap :]
                self.d[self.pos, : self.overlap] = self.d[self.pos - 1, -self.overlap :]
                # Fill the mask with 1 for the valid steps
                self.m[self.pos, : self.overlap] = 1
                self.time_pos = self.overlap
            if done:  # n_envs == 1
                # Move the time position to the beginning of the chunk
                self.time_pos = 0
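
    # Example trace of `add` (chunk_len = 4, overlap = 2): six consecutive
    # steps t0..t5 of a single episode are stored as
    #     row 0: [t0, t1, t2, t3]
    #     row 1: [t2, t3, t4, t5]  # first `overlap` slots copied from row 0
    # and after the copy, `time_pos` resumes at `overlap` (= 2 here).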

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> RecurrentReplayBufferSamples:
        """
        :param batch_inds: Indices of the chunks (rows) to sample
        :param env: Optional VecNormalize wrapper to normalize observations
        :return: A batch of chunks of episodes
        """
        o = self.o[batch_inds]
        a = self.a[batch_inds]
        r = self.r[batch_inds]
        d = self.d[batch_inds]
        m = self.m[batch_inds]
        o = self._normalize_obs(o, env)
        data = (o, a, r, d, m)
        return RecurrentReplayBufferSamples(*tuple(map(self.to_torch, data)))
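

if __name__ == "__main__":
    # Minimal usage sketch, not part of the buffer API: fill the buffer from
    # Pendulum-v1 with random actions and sample one batch. The environment
    # and hyperparameters here are illustrative assumptions.
    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    buf = RecurrentReplayBuffer(
        64, env.observation_space, env.action_space, chunk_len=40, overlap=5
    )
    obs, _ = env.reset()
    for _ in range(100):
        action = env.action_space.sample()
        next_obs, reward, term, trunc, _ = env.step(action)
        buf.add(obs, next_obs, action, reward, term)
        obs = env.reset()[0] if (term or trunc) else next_obs
    batch = buf.sample(8)
    print(batch.observations.shape)  # (8, 41, 3): (B, chunk_len + 1, obs_dim)
    # In a recurrent TD loss, batch.mask zeroes out the padded steps, e.g.
    #   loss = ((q - target) * batch.mask).pow(2).sum() / batch.mask.sum()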

# Tests for RecurrentReplayBuffer (a separate module: the import below
# expects the buffer code above to live in buffers.py).
import gymnasium as gym
import numpy as np

from buffers import RecurrentReplayBuffer


def random_sample(done=False, prev_obs=None):
    return (
        prev_obs if prev_obs is not None else np.random.rand(3).astype(np.float32),
        np.random.rand(3).astype(np.float32),
        np.random.rand(1).astype(np.float32),
        np.random.rand(),
        done,
    )
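
# The tuple returned by random_sample matches the signature of
# RecurrentReplayBuffer.add: (obs, next_obs, action, reward, done),
# with Pendulum-v1 shapes (obs_dim = 3, act_dim = 1).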


def _init_buffer(buffer_size=8, elements=0, chunk_len=4, overlap=1, env=None):
    if env is not None:
        buffer = RecurrentReplayBuffer(
            buffer_size,
            env.observation_space,
            env.action_space,
            chunk_len=chunk_len,
            overlap=overlap,
        )
        o, _ = env.reset()
    else:
        assert False, "TODO: Use mock observation and action spaces"
    for _ in range(elements):
        if env is not None:
            a = env.action_space.sample()
            o2, r, term, trunc, _ = env.step(a)
            # Fill the buffer with random samples
            buffer.add(o, o2, a, r, term)
            o = o2
        else:
            buffer.add(*random_sample())
    return buffer


def test_add():
    env = gym.make("Pendulum-v1")
    buffer_size = 8
    chunk_len = 5
    overlap = 3
    buffer = _init_buffer(
        buffer_size=buffer_size, chunk_len=chunk_len, overlap=overlap, env=env
    )
    assert np.abs(buffer.o).sum() == 0, "Buffer should be empty"
    assert np.abs(buffer.m).sum() == 0, "Buffer should be empty"
    assert buffer.pos == 0, "Buffer should be empty"
    assert buffer.time_pos == 0, "Buffer should be empty"

    sa = random_sample()
    buffer.add(*sa)
    assert buffer.pos == 0, "Position should not have changed"
    assert buffer.time_pos == 1, "Time position should have increased"
    assert np.allclose(buffer.o[0][0], sa[0]), "Step 0: observations should be recorded"
    assert np.allclose(
        buffer.o[0][1], sa[1]
    ), "Step 0: next observations should be recorded"
    assert np.allclose(buffer.a[0][0], sa[2]), "Step 0: actions should be recorded"
    assert np.allclose(buffer.r[0][0], sa[3]), "Step 0: rewards should be recorded"
    assert np.allclose(buffer.d[0][0], False), "Step 0: dones should be recorded"
    assert buffer.m[0][0] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 1, "Mask should be updated"

    sa = random_sample(prev_obs=sa[1])
    buffer.add(*sa)
    assert buffer.pos == 0, "Position should not have changed"
    assert buffer.time_pos == 2, "Time position should have increased"
    assert np.allclose(buffer.o[0][1], sa[0]), "Step 1: observations should be recorded"
    assert np.allclose(
        buffer.o[0][2], sa[1]
    ), "Step 1: next observations should be recorded"
    assert np.allclose(buffer.a[0][1], sa[2]), "Step 1: actions should be recorded"
    assert np.allclose(buffer.r[0][1], sa[3]), "Step 1: rewards should be recorded"
    assert np.allclose(buffer.d[0][1], False), "Step 1: dones should be recorded"
    assert buffer.m[0][1] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 2, "Mask should be updated"

    # Simulate a done episode
    sa = random_sample(done=True, prev_obs=sa[1])
    buffer.add(*sa)
    assert buffer.pos == 1, "New chunk should have started"
    assert buffer.time_pos == 0, "Time position should have been reset"
    assert np.allclose(buffer.o[0][2], sa[0]), "Step 2: observations should be recorded"
    assert np.allclose(
        buffer.o[0][3], sa[1]
    ), "Step 2: next observations should be recorded"
    assert np.allclose(buffer.a[0][2], sa[2]), "Step 2: actions should be recorded"
    assert np.allclose(buffer.r[0][2], sa[3]), "Step 2: rewards should be recorded"
    assert np.allclose(buffer.d[0][2], True), "Step 2: dones should be recorded"
    assert buffer.m[0][2] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 3, "Mask should be updated"
    print("New chunk:\n", buffer.o[1])

    # Test automatic chunking
    for i in range(chunk_len - 1):
        sa = random_sample(prev_obs=sa[1])
        buffer.add(*sa)
        assert buffer.pos == 1, "Position should not have changed"
        assert buffer.time_pos == i + 1, "Time position should have increased"
        assert np.allclose(buffer.o[1][i], sa[0]), "Observations should be recorded"
        assert np.allclose(
            buffer.o[1][i + 1], sa[1]
        ), "Next observations should be recorded"
        assert np.allclose(buffer.a[1][i], sa[2]), "Actions should be recorded"
        assert np.allclose(buffer.r[1][i], sa[3]), "Rewards should be recorded"
        assert np.allclose(buffer.d[1][i], False), "Dones should be recorded"
        assert buffer.m[1][i] == 1, "Mask should be updated"
        assert np.abs(buffer.m).sum() == i + 4, "Mask should be updated"
    print("Current chunk:\n", buffer.o[1])

    # Here we should start a new chunk
    sa2 = random_sample(prev_obs=sa[1])
    buffer.add(*sa2)
    print("Prev obs", sa2[0])
    print("New obs", sa2[1])
    print("Current chunk:\n", buffer.o[1])
    assert not buffer.full, "Buffer should not be full"
    assert buffer.pos == 2, "New chunk should have started"
    assert buffer.time_pos == overlap, "Time position should have been moved to overlap"
    assert np.allclose(buffer.o[1][chunk_len - 1], sa2[0]), "End of previous chunk"
    assert np.allclose(
        buffer.o[1][chunk_len], sa2[1]
    ), "End of previous chunk - new obs"
    assert buffer.m[1][chunk_len - 1] == 1, "Mask should be updated"
    print("New chunk:\n", buffer.o[2])
    assert np.allclose(
        buffer.o[2][1], sa[0]
    ), "Overlap: old observations should have been preserved"
    assert np.allclose(
        buffer.o[2][2], sa2[0]
    ), "Overlap: current observations should be recorded"
    assert np.allclose(
        buffer.o[2][3], sa2[1]
    ), "Overlap: next observations should be recorded"
    assert buffer.m[2][0] == 1, "Overlap: mask should be updated"
    assert buffer.m[2][1] == 1, "Overlap: mask should be updated"
    assert buffer.m[2][2] == 1, "Overlap: mask should be updated"
    assert buffer.m[2][3] == 0, "Mask should remain 0"

    sa2 = random_sample(prev_obs=sa2[1])
    buffer.add(*sa2)
    print("Current chunk:\n", buffer.o[2])
    assert buffer.pos == 2, "Position should remain the same"
    assert buffer.time_pos == overlap + 1, "Time position should have increased"
    assert np.allclose(
        buffer.o[2][3], sa2[0]
    ), "Current observations should be recorded"
    assert np.allclose(
        buffer.o[2][4], sa2[1]
    ), "Next observations should be recorded"

    # Edge case test: end of chunk and done
    sa2 = random_sample(prev_obs=sa2[1], done=True)
    buffer.add(*sa2)
    print("Current chunk:\n", buffer.o[2])
    assert buffer.pos == 3, "Position should have increased"
    assert buffer.time_pos == 0, "Time position should have been reset"
    assert not buffer.full, "Buffer should not be full"
    print("New empty chunk:\n", buffer.o[3])

    # Fill the buffer until it is full
    for test_pos in range(4, buffer_size):
        print("test_pos", test_pos)
        sa2 = random_sample(done=True)
        buffer.add(*sa2)
        assert np.allclose(
            buffer.o[test_pos - 1][0], sa2[0]
        ), "Current observations should be recorded"
        assert np.allclose(
            buffer.o[test_pos - 1][1], sa2[1]
        ), "Next observations should be recorded"
        assert buffer.pos == test_pos, "Position should have increased"
        assert buffer.time_pos == 0, "Time position should have been reset"
        assert not buffer.full, "Buffer should not be full"
    sa2 = random_sample(done=True)
    buffer.add(*sa2)
    assert buffer.full, "Buffer should be full"


def test_sample():
    env = gym.make("Pendulum-v1")
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    buffer_size = 200
    chunk_len = 40
    overlap = 5
    buffer = _init_buffer(
        elements=100,
        buffer_size=buffer_size,
        chunk_len=chunk_len,
        overlap=overlap,
        env=env,
    )
    batch = buffer.sample(32)
    print("observations.shape", batch.observations.shape)
    print("rewards.shape", batch.rewards.shape)
    assert len(batch.observations) == 32, "Batch should have 32 elements"
    assert batch.observations.shape == (
        32,
        chunk_len + 1,
        obs_dim,
    ), "Observations shape should be (32, chunk_len + 1, obs_dim)"
    assert batch.actions.shape == (
        32,
        chunk_len,
        act_dim,
    ), "Actions shape should be (32, chunk_len, act_dim)"
    assert batch.rewards.shape == (32, chunk_len, 1), "Rewards shape should be (32, chunk_len, 1)"
    assert batch.dones.shape == (32, chunk_len, 1), "Dones shape should be (32, chunk_len, 1)"
    assert batch.mask.shape == (32, chunk_len, 1), "Mask shape should be (32, chunk_len, 1)"
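

# Optional: run the tests directly without pytest (a convenience sketch,
# assuming both test functions above live in this module).
if __name__ == "__main__":
    test_add()
    test_sample()
    print("All tests passed")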