Simple code for imitation learning on hw1: https://github.com/berkeleydeeprlcourse/homework/tree/master/hw1
# written by Harri Edwards, edited by Gavin Gray
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import gym
import tf_util
import load_policy
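# Behavioural cloning on the hw1 Ant-v1 task: roll out the provided expert
# policy to gather (observation, action) pairs, fit a small MLP to that data
# by regression, and periodically roll out the learned network to see how
# its return compares to the expert's.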
def main():
    envname = "Ant-v1"
    expert_policy_file = os.path.join("experts", envname + ".pkl")
    print('loading and building expert policy')
    policy_fn = load_policy.load_policy(expert_policy_file)
    print('loaded and built')

    max_timesteps = 200   # cap on steps per episode
    num_rollouts = 1000   # expert rollouts to gather, one training step after each
    render = False

    env = gym.make(envname)
    # defining imitation network
    x = tf.placeholder(tf.float32, shape=(None, env.observation_space.shape[0]), name="inputs")
    action_var = tf.placeholder(tf.float32, shape=(None, env.action_space.shape[0]), name="actions")
    h = slim.fully_connected(x, 100)
    pred = slim.fully_connected(h, env.action_space.shape[0], activation_fn=None)
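    # (slim.fully_connected applies ReLU by default, so this is a one-hidden-layer
    # MLP: 100 ReLU units followed by a linear readout of the action.)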
    # using squared error, but not necessarily the best choice,
    # and no regularisation!
    error = tf.reduce_mean(tf.reduce_sum((pred - action_var) ** 2, axis=1))
    # expressions for training the imitation network
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    gradients = optimizer.compute_gradients(error)
    train_var = optimizer.apply_gradients(gradients)
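    # (compute_gradients followed by apply_gradients is equivalent to
    # optimizer.minimize(error); splitting it out leaves room to inspect
    # or clip the gradients later.)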
    with tf.Session() as sess:
        tf_util.initialize()  # runs the TF variable initializers
        max_steps = max_timesteps or env.spec.timestep_limit
        def get_learned_rollout(render=False):
            """Run the imitation network we've learned in the
            environment and report its return."""
            returns = []
            observations = []
            actions = []
            obs = env.reset()
            done = False
            totalr = 0.
            steps = 0
            while not done:
                action = sess.run(pred, {x: obs.reshape(1, -1)})
                # (for comparison, the expert could be queried here instead:
                # action = policy_fn(obs[None, :]))
                observations.append(obs)
                actions.append(action)
                obs, r, done, _ = env.step(action)
                totalr += r
                steps += 1
                if render:
                    env.render()
                if steps >= max_steps:
                    break
            returns.append(totalr)
            print("learned return", totalr)
            return returns
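        # first rollout is with the untrained network, as a baseline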
        get_learned_rollout()

        returns = []
        observations = []
        actions = []
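        # note: these lists are never reset, so the expert data accumulates and
        # each gradient step below fits the whole dataset gathered so far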
        for i in range(num_rollouts):
            print('iter', i)
            obs = env.reset()
            done = False
            totalr = 0.
            steps = 0
            # gather one expert rollout of training data and learn on it
            while not done:
                action = policy_fn(obs[None, :])
                observations.append(obs)
                actions.append(action)
                obs, r, done, _ = env.step(action)
                totalr += r
                steps += 1
                if render:
                    env.render()
                if steps >= max_steps:
                    break
            returns.append(totalr)
            # put the accumulated expert data in a dictionary
            expert_data = {'observations': np.array(observations),
                           'actions': np.array(actions).reshape(-1, env.action_space.shape[0])}
            # train the imitation network: one Adam step on all the data so far
            e, _ = sess.run([error, train_var], {x: expert_data["observations"],
                                                 action_var: expert_data["actions"]})
            print("Squared error:", e)
            get_learned_rollout(render=(i % 10 == 0))
if __name__ == '__main__':
    main()
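# To run this (assuming the hw1 layout: an experts/Ant-v1.pkl weights file plus
# tf_util.py and load_policy.py alongside this script, TensorFlow 1.x, and
# mujoco-py installed), save it under a name of your choice, e.g.
#   python behavioural_cloning.py   # hypothetical filename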