PPO training
import gym
import numpy as np
import tensorflow as tf

# Assumes MlpCategoricalActorCritic and PPOBuffer are defined in the companion files of this gist.
def ppo(seed=0, steps_per_epoch=4000, epochs=50, gamma=0.99, clip_ratio=0.2, pi_lr=3e-4,
        vf_lr=1e-3, train_pi_iters=80, train_v_iters=80, lam=0.97, max_ep_len=1000, target_kl=0.01):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    env = gym.make('CartPole-v1')
    ob_space = env.observation_space
    ac_space = env.action_space
    obs_dim = ob_space.shape
    act_dim = ac_space.shape
    # Actor-critic model with a categorical (discrete-action) policy
    model = MlpCategoricalActorCritic(ob_space, ac_space)
    # Separate optimizers for the policy and the value function
    opt_pi = tf.keras.optimizers.Adam(learning_rate=pi_lr)
    opt_v = tf.keras.optimizers.Adam(learning_rate=vf_lr)
    # Experience buffer holding one epoch of on-policy rollouts
    local_steps_per_epoch = int(steps_per_epoch)
    buf = PPOBuffer(ob_space, ac_space, local_steps_per_epoch, gamma, lam)
    # Trainable weights for the actor and critic networks
    actor_weights = model.actor_mlp.trainable_weights
    critic_weights = model.critic_mlp.trainable_weights
    @tf.function
    def update(obs, acs, advs, rets, logp_olds):
        # Policy update: maximize the PPO clipped surrogate objective, with
        # early stopping once the approximate KL divergence between the new
        # and old policies exceeds 1.5 * target_kl.
        stopIter = tf.constant(train_pi_iters)
        pi_loss = 0.
        for i in tf.range(train_pi_iters):
            with tf.GradientTape() as tape:
                logp = model.get_logp(obs, acs)
                ratio = tf.exp(logp - logp_olds)
                min_adv = tf.where(advs > 0, (1+clip_ratio)*advs, (1-clip_ratio)*advs)
                pi_loss = -tf.reduce_mean(tf.minimum(ratio * advs, min_adv))
            grads = tape.gradient(pi_loss, actor_weights)
            opt_pi.apply_gradients(zip(grads, actor_weights))
            kl = tf.reduce_mean(logp_olds - logp)  # sample-based KL estimate
            if kl > 1.5 * target_kl:
                stopIter = i
                break
        # Value-function update: mean-squared-error regression to the returns
        v_loss = 0.
        for i in tf.range(train_v_iters):
            with tf.GradientTape() as tape:
                v = model.get_v(obs)
                v_loss = tf.reduce_mean((rets - v)**2)
            grads = tape.gradient(v_loss, critic_weights)
            opt_v.apply_gradients(zip(grads, critic_weights))
        return pi_loss, v_loss, stopIter
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    # Main loop: collect experience in the env and update/log each epoch
    for epoch in range(epochs):
        Ep_Ret = []
        for t in range(local_steps_per_epoch):
            expand_o = tf.constant(o.reshape(1, -1))
            a, logp_t, v_t = model.get_pi_logpi_vf(expand_o)
            a = a.numpy()[0]
            logp_t = logp_t.numpy()[0]
            v_t = v_t.numpy()[0][0]
            buf.store(o, a, r, v_t, logp_t)
            o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1
            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' % ep_len)
                # Bootstrap with the value estimate if the trajectory did not terminate
                last_val = r if d else model.get_v(tf.constant(o.reshape(1, -1))).numpy()[0]
                buf.finish_path(last_val)
                if terminal:
                    Ep_Ret.append(ep_ret)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        # Run the PPO update on this epoch's batch of experience
        obs, acs, advs, rets, logp_olds = buf.get()
        pi_loss, v_loss, stopIter = update(obs, acs, advs, rets, logp_olds)
        print('---------------------------------')
        print('epoch {}'.format(epoch))
        print('pi loss {}'.format(pi_loss.numpy()))
        print('vf loss {}'.format(v_loss.numpy()))
        print('stop iter {}'.format(stopIter.numpy()))
        print('Ep Ret {}'.format(np.mean(Ep_Ret)))
    return model, env
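
# --- Usage sketch (not part of the original gist) ---
# A minimal, hypothetical example of calling ppo() above: train briefly, then
# roll out one evaluation episode with the returned model. It assumes the same
# old-style gym step()/reset() API used in the training loop; the epoch count
# is illustrative.
if __name__ == '__main__':
    trained_model, trained_env = ppo(seed=0, epochs=10)
    o, d, ep_ret = trained_env.reset(), False, 0
    while not d:
        a, _, _ = trained_model.get_pi_logpi_vf(tf.constant(o.reshape(1, -1)))
        o, r, d, _ = trained_env.step(a.numpy()[0])
        ep_ret += r
    print('evaluation episode return: {}'.format(ep_ret))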