@tanzhenyu
Last active August 23, 2019 15:39
PPO Buffer
import numpy as np
import scipy.signal


def discount_cumsum(x, discount):
    # Discounted cumulative sum of a vector:
    # output[t] = x[t] + discount * x[t+1] + discount^2 * x[t+2] + ...
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


def combined_shape(length, shape=None):
    # Prepend a batch dimension of `length` to an (optional) per-item shape.
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)


class PPOBuffer:
    """Stores trajectories collected by a PPO agent, and computes GAE-Lambda
    advantage estimates and rewards-to-go for each trajectory."""

    def __init__(self, ob_space, ac_space, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros(combined_shape(size, ob_space.shape), dtype=ob_space.dtype)
        self.act_buf = np.zeros(combined_shape(size, ac_space.shape), dtype=ac_space.dtype)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        # Append one timestep of agent-environment interaction to the buffer.
        assert self.ptr < self.max_size  # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        # Call at the end of a trajectory, or when a trajectory is cut off by the
        # end of an epoch. last_val should be 0 if the episode terminated, and
        # V(s_T), the value estimate of the last state, otherwise (to bootstrap).
        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        # the next two lines implement GAE-Lambda advantage calculation
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
        # the next line computes rewards-to-go, to be targets for the value function
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        self.path_start_idx = self.ptr

    def get(self):
        # Call at the end of an epoch to retrieve all data from the buffer, with
        # advantages normalized; also resets the buffer pointers.
        assert self.ptr == self.max_size  # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0
        # the next two lines implement the advantage normalization trick
        adv_mean = np.mean(self.adv_buf)
        adv_std = np.std(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        return [self.obs_buf, self.act_buf, self.adv_buf,
                self.ret_buf, self.logp_buf]
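
A minimal usage sketch, not part of the original gist: it fills the buffer with placeholder rollout data and reads it back. The Gym spaces, the steps_per_epoch value, and the random reward/value/log-prob numbers are illustrative assumptions; in a real PPO loop the observations and actions come from the environment and the values and log-probabilities from the actor-critic.

# Usage sketch (illustrative only; assumes gym is installed for its space classes).
from gym.spaces import Box, Discrete

ob_space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)  # stand-in observation space
ac_space = Discrete(2)                                            # stand-in action space
steps_per_epoch = 8                                               # illustrative epoch length

buf = PPOBuffer(ob_space, ac_space, size=steps_per_epoch)

for t in range(steps_per_epoch):
    obs = ob_space.sample()        # placeholder for the env observation
    act = ac_space.sample()        # placeholder for the sampled action
    rew = float(np.random.rand())  # placeholder reward from the env
    val = float(np.random.rand())  # placeholder value estimate V(s_t)
    logp = -0.7                    # placeholder log pi(a_t | s_t)
    buf.store(obs, act, rew, val, logp)
    if t == steps_per_epoch - 1:
        # epoch ends mid-episode: bootstrap the cut-off trajectory with V(s_T)
        buf.finish_path(last_val=val)

obs_b, act_b, adv_b, ret_b, logp_b = buf.get()
print(adv_b.mean(), adv_b.std())   # roughly 0 and 1 after normalization

In the full algorithm, these arrays feed the PPO clipped-surrogate policy update (which uses the advantages and stored log-probabilities) and the value-function regression against the rewards-to-go.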