Skip to content

Instantly share code, notes, and snippets.

@djbyrne
Created March 26, 2020 08:14
Show Gist options
  • Save djbyrne/45b6cbc620c8acbef259c7d519bca80f to your computer and use it in GitHub Desktop.
Save djbyrne/45b6cbc620c8acbef259c7d519bca80f to your computer and use it in GitHub Desktop.
Basic Replay Buffer
# A single experience step gathered during training.
# Field order matters: ReplayBuffer.sample unpacks stored experiences
# positionally as (state, action, reward, done, new_state).
Experience = collections.namedtuple(
    typename='Experience',
    field_names=[
        'state',
        'action',
        'reward',
        'done',
        'new_state',
    ],
)
class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them

    Experiences are kept in a bounded FIFO: once ``capacity`` items are stored,
    each new append evicts the oldest one.

    Args:
        capacity: maximum number of experiences held by the buffer
    """

    def __init__(self, capacity: int) -> None:
        # deque(maxlen=...) discards the oldest entry automatically when full.
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self) -> int:
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience: "Experience") -> None:
        """
        Add experience to the buffer

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        """
        Draw a random batch of distinct experiences (sampling without replacement).

        Args:
            batch_size: number of experiences to draw; must be between 1 and len(self)

        Returns:
            tuple of arrays (states, actions, rewards, dones, next_states), with one
            entry per sampled experience along the first axis; rewards are float32
            and dones are bool

        Raises:
            ValueError: if batch_size is not in the range [1, len(self)]
        """
        stored = len(self.buffer)
        # Fail early with a clear message; otherwise numpy (too large / negative)
        # or tuple unpacking (zero) raises a far less helpful ValueError.
        if not 1 <= batch_size <= stored:
            raise ValueError(
                f"batch_size must be between 1 and the number of stored "
                f"experiences ({stored}), got {batch_size}"
            )
        indices = np.random.choice(stored, batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(
            *(self.buffer[idx] for idx in indices)
        )
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            np.array(dones, dtype=bool),
            np.array(next_states),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment