Load a policy trained with OpenAI baselines, then visualize and render it. Written specifically for Reacher-v2.
import gym
import tensorflow as tf
from baselines.ppo2 import policies
from baselines.common import tf_util as U

env = gym.make("Reacher-v2")

# Build a single-step MlpPolicy (nbatch=1, nsteps=1) with the same
# architecture used during training, so the checkpoint variables line up.
def policy_fn(s, ob_space, ac_space):
    return policies.MlpPolicy(s, ob_space=ob_space, ac_space=ac_space,
                              nbatch=1, nsteps=1)

sess = U.make_session(num_cpu=1)
sess.__enter__()
pi = policy_fn(sess, env.observation_space, env.action_space)

# Restore the trained weights into the freshly built graph.
tf.train.Saver().restore(sess, '/tmp/model')

for episode in range(10):
    obs = env.reset()  # start each episode from a fresh state
    R = 0.0
    for i in range(50):  # Reacher-v2 episodes are 50 steps long
        env.render()
        obs.shape = (1, 11)  # add a batch axis; Reacher-v2 observations have 11 dims
        action = pi.step(obs)[0]
        obs, reward, done, info = env.step(action)
        print(action, reward)
        R += reward
    print("R = {}".format(R))
    print('-------------------------------')
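
For reference, a minimal sketch of how a checkpoint at /tmp/model could be produced so the restore above has something to load. This is an assumption, not part of the original gist: the saver call is the standard TF1 tf.train.Saver API, but the training step itself is elided because the exact baselines ppo2 training interface varies between versions.

import gym
import tensorflow as tf
from baselines.ppo2 import policies
from baselines.common import tf_util as U

env = gym.make("Reacher-v2")

sess = U.make_session(num_cpu=1)
sess.__enter__()

# Same policy graph as in the loading script, so variable names match.
pi = policies.MlpPolicy(sess, ob_space=env.observation_space,
                        ac_space=env.action_space, nbatch=1, nsteps=1)

# ... train the policy here; the training call is elided because the
# ppo2 learn() signature differs across baselines versions ...

# Save all graph variables under the prefix the loading script restores.
tf.train.Saver().save(sess, '/tmp/model')

Note that tf.train.Saver().restore(sess, '/tmp/model') expects this same checkpoint prefix, so the save path and restore path must agree.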