Created
February 19, 2019 11:54
-
-
Save TomLin/c84f9bd64685157e9789b8a40f41b36b to your computer and use it in GitHub Desktop.
Sampling method in D4PG in alignment to required data type.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Excerpt of replay memory object.
class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def sample2(self, device=device):
        """Randomly sample a batch of experiences from memory.

        Draws ``self.batch_size`` experiences uniformly without
        replacement from ``self.memory`` and stacks each field into a
        batched tensor on *device*.

        Parameters
        ----------
        device : torch.device or str, optional
            Target device for the returned tensors. Defaults to the
            module-level ``device``.

        Returns
        -------
        tuple
            ``(states, actions, rewards, next_states, dones)`` where the
            first four are float32 tensors and ``dones`` is a uint8
            (byte) mask tensor, matching what the D4PG update expects.
        """
        batch = random.sample(self.memory, k=self.batch_size)

        # Each stored state/action carries a leading singleton batch dim;
        # squeeze it so np.array stacks into (batch_size, ...) arrays.
        # NOTE(review): assumes exp.state/.action/.next_state support
        # .squeeze(0) and convert cleanly via np.array — per original code.
        states = np.array([e.state.squeeze(0) for e in batch], dtype=np.float32)
        actions = np.array([e.action.squeeze(0) for e in batch], dtype=np.float32)
        rewards = np.array([e.reward for e in batch], dtype=np.float32)
        next_states = np.array([e.next_state.squeeze(0) for e in batch], dtype=np.float32)
        dones = [e.done for e in batch]

        # from_numpy shares the freshly-built array's memory instead of the
        # extra copy torch.Tensor(...) would make; arrays are already
        # float32, so dtypes match the original behavior exactly.
        states_v = torch.from_numpy(states).to(device)
        actions_v = torch.from_numpy(actions).to(device)
        rewards_v = torch.from_numpy(rewards).to(device)
        next_states_v = torch.from_numpy(next_states).to(device)
        # torch.tensor(..., dtype=torch.uint8) replaces the legacy
        # ByteTensor constructor while producing the same uint8 dtype.
        dones_v = torch.tensor(dones, dtype=torch.uint8).to(device)

        return states_v, actions_v, rewards_v, next_states_v, dones_v
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment