# Replay buffer gist by @Aroksak, created November 25, 2022.

import torch


class ReplayBuffer:
    """Fixed-size circular buffer of transitions for off-policy RL."""

    def __init__(self, max_size=10_000, device='cpu'):
        self.size = max_size
        self.n_stored = 0   # number of transitions currently stored
        self.next_idx = 0   # position where the next transition is written
        self.device = device
        # Storage tensors are allocated lazily on the first call to store(),
        # once the state shape is known.
        self.state = None
        self.action = None
        self.next_state = None
        self.reward = None
        self.done = None

    def is_samplable(self, replay_size):
        # A batch can be drawn once at least replay_size transitions are stored.
        return replay_size <= self.n_stored

    def store(self, state: torch.Tensor, action: int, next_state: torch.Tensor, reward: float, is_done: bool):
        if self.state is None:
            # Allocate storage tensors, inferring the state shape from the first sample.
            self.state = torch.empty([self.size] + list(state.shape), dtype=torch.float32, device=self.device)
            self.action = torch.empty(self.size, dtype=torch.long, device=self.device)
            self.next_state = torch.empty([self.size] + list(state.shape), dtype=torch.float32, device=self.device)
            self.reward = torch.empty(self.size, dtype=torch.float32, device=self.device)
            self.done = torch.empty(self.size, dtype=torch.bool, device=self.device)
        # Write the transition at the current position, overwriting the oldest
        # entry once the buffer is full.
        self.state[self.next_idx] = state
        self.action[self.next_idx] = action
        self.next_state[self.next_idx] = next_state
        self.reward[self.next_idx] = reward
        self.done[self.next_idx] = is_done
        self.next_idx = (self.next_idx + 1) % self.size
        self.n_stored = min(self.size, self.n_stored + 1)

    def get_sample(self, replay_size):
        # Sample replay_size distinct transitions uniformly at random
        # (callers should check is_samplable first).
        idxes = torch.randperm(self.n_stored)[:replay_size]
        return self.state[idxes], self.action[idxes], self.next_state[idxes], self.reward[idxes], self.done[idxes]
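

# --- Usage sketch (not part of the original gist) ---
# A minimal example of filling the buffer with synthetic transitions and
# drawing a batch. The state shape (4,), action space size, and batch size
# below are arbitrary assumptions for illustration only.
if __name__ == "__main__":
    buffer = ReplayBuffer(max_size=1_000, device='cpu')

    # Store some fake transitions.
    for step in range(100):
        state = torch.randn(4)
        next_state = torch.randn(4)
        action = int(torch.randint(0, 2, (1,)))
        reward = float(torch.rand(1))
        is_done = bool(step % 20 == 19)
        buffer.store(state, action, next_state, reward, is_done)

    # Draw a batch once enough transitions are available.
    batch_size = 32
    if buffer.is_samplable(batch_size):
        states, actions, next_states, rewards, dones = buffer.get_sample(batch_size)
        print(states.shape, actions.shape, rewards.shape)  # torch.Size([32, 4]) torch.Size([32]) torch.Size([32])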