@TomLin
Created February 19, 2019 11:54
Sampling method in D4PG, aligned to the required data types.
# Excerpt of the replay memory object.
import random

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def sample2(self, device=device):
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        states, actions, rewards, next_states, dones = [], [], [], [], []
        for exp in experiences:
            # Drop the leading batch dimension kept on each stored tensor.
            states.append(exp.state.squeeze(0))
            actions.append(exp.action.squeeze(0))
            rewards.append(exp.reward)
            dones.append(exp.done)
            next_states.append(exp.next_state.squeeze(0))

        # Stack into float32 arrays, then move the tensors to the training device.
        states_v = torch.Tensor(np.array(states, dtype=np.float32)).to(device)
        actions_v = torch.Tensor(np.array(actions, dtype=np.float32)).to(device)
        rewards_v = torch.Tensor(np.array(rewards, dtype=np.float32)).to(device)
        next_states_v = torch.Tensor(np.array(next_states, dtype=np.float32)).to(device)
        dones_v = torch.ByteTensor(dones).to(device)  # uint8 episode-termination mask
        return states_v, actions_v, rewards_v, next_states_v, dones_v
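
For context, a minimal usage sketch follows. The Experience namedtuple, the memory deque, the batch_size attribute, and the observation/action sizes are assumptions inferred from the excerpt; the gist does not show the buffer's __init__.

# Usage sketch (assumed setup; the real __init__ is not part of this excerpt).
from collections import deque, namedtuple

Experience = namedtuple(
    "Experience", ["state", "action", "reward", "next_state", "done"])

buffer = ReplayBuffer.__new__(ReplayBuffer)  # bypass the unshown __init__
buffer.memory = deque(maxlen=100_000)
buffer.batch_size = 4

for _ in range(16):
    # States/actions are stored with a leading batch dimension of 1,
    # which sample2 strips via squeeze(0). Sizes 33 and 4 are arbitrary here.
    e = Experience(state=torch.randn(1, 33), action=torch.randn(1, 4),
                   reward=1.0, next_state=torch.randn(1, 33), done=False)
    buffer.memory.append(e)

states_v, actions_v, rewards_v, next_states_v, dones_v = buffer.sample2()
print(states_v.shape, actions_v.shape, dones_v.dtype)
# torch.Size([4, 33]) torch.Size([4, 4]) torch.uint8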