Skip to content

Instantly share code, notes, and snippets.

@sol0invictus
Created May 30, 2020 18:38
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sol0invictus/6f83115815f4f2fb10caf0e961d48e20 to your computer and use it in GitHub Desktop.
Save sol0invictus/6f83115815f4f2fb10caf0e961d48e20 to your computer and use it in GitHub Desktop.
DDPG - Replay Buffer
class BasicBuffer:
    """Fixed-capacity replay buffer for DDPG-style off-policy training.

    Transitions are stored in preallocated float32 NumPy arrays. A circular
    write pointer is used, so once ``max_size`` transitions have been pushed
    each new transition overwrites the oldest one.

    Attributes:
        obs1_buf: ``(max_size, obs_dim)`` observations ``s``.
        obs2_buf: ``(max_size, obs_dim)`` next observations ``s2``.
        acts_buf: ``(max_size, act_dim)`` actions ``a``.
        rews_buf: ``(max_size,)`` rewards ``r``.
        done_buf: ``(max_size,)`` terminal flags ``d`` stored as float32
            (booleans become 0.0 / 1.0).
        ptr: Index the next ``push`` writes to.
        size: Number of valid transitions currently stored.
        max_size: Capacity passed to the constructor.
    """

    def __init__(self, size, obs_dim, act_dim):
        """Preallocate zeroed storage for ``size`` transitions.

        Args:
            size: Maximum number of transitions kept (buffer capacity).
            obs_dim: Length of each observation vector.
            act_dim: Length of each action vector.
        """
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros([size], dtype=np.float32)
        self.done_buf = np.zeros([size], dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def push(self, obs, act, rew, next_obs, done):
        """Store one transition, overwriting the oldest entry when full.

        Args:
            obs: Observation, array-like assignable to a row of shape
                ``(obs_dim,)``.
            act: Action, array-like assignable to a row of shape
                ``(act_dim,)``.
            rew: Reward; a scalar or any single-element array-like.
            next_obs: Next observation, assignable to a row of shape
                ``(obs_dim,)``.
            done: Terminal flag; a bool/number or any single-element
                array-like.

        Raises:
            ValueError: if ``rew`` or ``done`` holds more than one element,
                or if an observation/action cannot be assigned to its row.
        """
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        # ``.item()`` unwraps scalars and size-1 arrays of any rank into a
        # Python scalar. Writing an ndim>0 array such as ``np.asarray([rew])``
        # into a single element is deprecated since NumPy 1.25.
        self.rews_buf[self.ptr] = np.asarray(rew).item()
        self.done_buf[self.ptr] = np.asarray(done).item()
        # Circular write pointer: wrap back to slot 0 at capacity.
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size=32):
        """Draw a uniformly random minibatch of stored transitions.

        Indices are drawn independently with ``np.random.randint`` (NumPy's
        global RNG). Sampling is therefore with replacement: duplicates are
        possible and ``batch_size`` may exceed ``self.size``.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(s, a, r, s2, d)`` of float32 arrays with shapes
            ``(batch_size, obs_dim)``, ``(batch_size, act_dim)``,
            ``(batch_size, 1)``, ``(batch_size, obs_dim)`` and
            ``(batch_size,)`` respectively.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty BasicBuffer")
        idxs = np.random.randint(0, self.size, size=batch_size)
        # NOTE(review): ``r`` is returned as a column but ``d`` stays 1-D.
        # Combining ``d`` with (batch_size, 1) values broadcasts to
        # (batch_size, batch_size) unless the caller reshapes it — confirm
        # callers do. Shapes are kept unchanged for backward compatibility.
        return (
            self.obs1_buf[idxs],
            self.acts_buf[idxs],
            self.rews_buf[idxs].reshape(-1, 1),
            self.obs2_buf[idxs],
            self.done_buf[idxs],
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment