
@djbyrne
Created March 26, 2020 08:23
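
from typing import Iterator, Tuple

from torch.utils.data.dataset import IterableDataset

# ReplayBuffer is defined elsewhere; it is expected to expose
# sample(n) -> (states, actions, rewards, dones, new_states).


class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer,
    which will be updated with new experiences during training.

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: "ReplayBuffer", sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
        # Draw a fresh sample from the buffer each time the dataset is iterated
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        # Yield experiences one at a time; the DataLoader re-batches them
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]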
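RLDataset only relies on the buffer's sample method returning five parallel sequences, but the buffer itself is not part of this gist. The following is a minimal, stdlib-only sketch of a compatible buffer. The names Experience, ReplayBuffer.append and the capacity parameter are hypothetical, and capping the sample at the buffer's current size is a design choice of this sketch, not necessarily how the original buffer behaves.

```python
import random
from collections import deque
from typing import Any, List, NamedTuple, Tuple


class Experience(NamedTuple):
    """Hypothetical container for a single transition."""
    state: Any
    action: Any
    reward: float
    done: bool
    new_state: Any


class ReplayBuffer:
    """Hypothetical fixed-capacity buffer; the oldest experiences are evicted first."""

    def __init__(self, capacity: int) -> None:
        self.buffer: deque = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple[List, List, List, List, List]:
        # Sample uniformly without replacement, never more than the buffer holds
        batch = random.sample(list(self.buffer), min(batch_size, len(self.buffer)))
        if not batch:
            return [], [], [], [], []
        # Transpose a list of experiences into five parallel lists
        states, actions, rewards, dones, new_states = (list(f) for f in zip(*batch))
        return states, actions, rewards, dones, new_states


# Fill a small buffer: with capacity 5, only transitions t = 3..7 survive
buffer = ReplayBuffer(capacity=5)
for t in range(8):
    buffer.append(Experience(state=t, action=t % 2, reward=1.0, done=(t == 7), new_state=t + 1))

states, actions, rewards, dones, new_states = buffer.sample(200)
print(len(states), sorted(states))  # 5 [3, 4, 5, 6, 7]
```

With a buffer of this shape, RLDataset(buffer, sample_size=200) could then be wrapped in torch.utils.data.DataLoader with a batch_size, so the experiences yielded one by one are collated back into batches for the training step.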