@gngdb
Last active April 4, 2017 11:24
Simple code for imitation learning on hw1: https://github.com/berkeleydeeprlcourse/homework/tree/master/hw1
# written by Harri Edwards, edited by Gavin Gray
import pickle
import tensorflow as tf
import numpy as np
import tf_util
import gym
import load_policy
import tensorflow.contrib.slim as slim
import os
import sys
def main():
    envname = "Ant-v1"
    expert_policy_file = os.path.join("experts", envname + ".pkl")
    print('loading and building expert policy')
    policy_fn = load_policy.load_policy(expert_policy_file)
    print('loaded and built')

    max_timesteps = 200
    num_rollouts = 1000
    render = False
    env = gym.make(envname)

    # defining imitation network: one hidden layer mapping observations to actions
    x = tf.placeholder(tf.float32, shape=(None, env.observation_space.shape[0]), name="inputs")
    action_var = tf.placeholder(tf.float32, shape=(None, env.action_space.shape[0]), name="actions")
    h = slim.fully_connected(x, 100)
    pred = slim.fully_connected(h, env.action_space.shape[0], activation_fn=None)
    # using squared error, but not necessarily the best choice
    error = tf.reduce_mean(tf.reduce_sum((pred - action_var)**2, 1))
    # and no regularisation!

    # expressions for training the imitation network
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    gradients = optimizer.compute_gradients(error)
    train_var = optimizer.apply_gradients(gradients)

    with tf.Session() as sess:
        tf_util.initialize()
        tf_util.get_session().run(tf.global_variables_initializer())
        max_steps = max_timesteps or env.spec.timestep_limit

        def get_learned_rollout(render=False):
            """Using the imitation network we've learned, run in the
            environment."""
            returns = []
            observations = []
            actions = []
            obs = env.reset()
            done = False
            totalr = 0.
            steps = 0
            while not done:
                action = tf_util.get_session().run(pred, {x: obs.reshape(1, -1)})
                #print("pred action", action)
                #action = policy_fn(obs[None, :])
                #print(action)
                observations.append(obs)
                actions.append(action)
                obs, r, done, _ = env.step(action)
                totalr += r
                steps += 1
                if render:
                    env.render()
                #if steps % 100 == 0: print("%i/%i" % (steps, max_steps))
                if steps >= max_steps:
                    break
            returns.append(totalr)
            print("learned return", totalr)
            return returns

        # check the untrained imitation network once before training starts
        get_learned_rollout()

        # these lists are never reset, so the expert dataset grows across iterations
        returns = []
        observations = []
        actions = []
        for i in range(num_rollouts):
            print('iter', i)
            obs = env.reset()
            done = False
            totalr = 0.
            steps = 0
            # gather a rollout of expert training data and learn on it
            while not done:
                action = policy_fn(obs[None, :])
                observations.append(obs)
                actions.append(action)
                obs, r, done, _ = env.step(action)
                totalr += r
                steps += 1
                if render:
                    env.render()
                #if steps % 100 == 0: print("%i/%i" % (steps, max_steps))
                if steps >= max_steps:
                    break
            returns.append(totalr)
            #print("expert return", totalr)
            # put all expert data gathered so far in a dictionary
            expert_data = {'observations': np.array(observations),
                           'actions': np.array(actions).reshape(-1, env.action_space.shape[0])}
            # train the imitation network: one Adam step on the aggregated expert data
            e, _ = sess.run([error, train_var], {x: expert_data["observations"],
                                                 action_var: expert_data["actions"]})
            print("Squared error: ", e)
            # roll out the learned policy, rendering every tenth iteration
            get_learned_rollout(render=i % 10 == 0)


if __name__ == '__main__':
    main()
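
The comments in main() flag that plain squared error with no regularisation is not necessarily the best choice. Below is a rough sketch, not part of the original gist, of how the same two-layer network could be built with L2 weight decay via slim and a Huber loss instead; build_imitation_net and the weight_decay / delta defaults are illustrative choices, not anything from the homework code.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def build_imitation_net(obs_dim, act_dim, weight_decay=1e-4, delta=1.0):
    # hypothetical sketch: same architecture as above, with L2 penalties and a
    # Huber loss in place of the unregularised squared error
    x = tf.placeholder(tf.float32, shape=(None, obs_dim), name="inputs")
    action_var = tf.placeholder(tf.float32, shape=(None, act_dim), name="actions")
    reg = slim.l2_regularizer(weight_decay)
    h = slim.fully_connected(x, 100, weights_regularizer=reg)
    pred = slim.fully_connected(h, act_dim, activation_fn=None, weights_regularizer=reg)
    # Huber loss: quadratic for small residuals, linear for large ones, so a few
    # badly-matched actions don't dominate the gradient
    resid = tf.abs(pred - action_var)
    huber = tf.where(resid < delta,
                     0.5 * tf.square(resid),
                     delta * (resid - 0.5 * delta))
    data_loss = tf.reduce_mean(tf.reduce_sum(huber, 1))
    # slim collects the L2 penalties in GraphKeys.REGULARIZATION_LOSSES
    reg_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    return x, action_var, pred, data_loss + reg_loss

The returned loss would simply replace error in the optimizer.compute_gradients call above; whether this actually helps on Ant-v1 would need checking.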