Created
March 1, 2018 00:53
-
-
Save RichardKelley/02f6e62971aa07a844387e49eb6a95ff to your computer and use it in GitHub Desktop.
Loading a policy trained with OpenAI baselines to visualize and render. Specifically for Reacher-v2.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Load a PPO2 policy trained with OpenAI baselines and roll it out on
Reacher-v2, rendering each step.

Assumes a TensorFlow checkpoint at /tmp/model whose variables match a
baselines.ppo2 MlpPolicy built with nbatch=1, nsteps=1.
"""
import gym
import tensorflow as tf
from baselines.ppo2 import policies
from baselines.common import set_global_seeds, tf_util as U

env = gym.make("Reacher-v2")


def policy_fn(sess, ob_space, ac_space):
    """Build a single-step (nbatch=1, nsteps=1) MlpPolicy bound to *sess*."""
    return policies.MlpPolicy(sess, ob_space=ob_space, ac_space=ac_space,
                              nbatch=1, nsteps=1)


sess = U.make_session(num_cpu=1)
sess.__enter__()
pi = policy_fn(sess, env.observation_space, env.action_space)
tf.train.Saver().restore(sess, '/tmp/model')

# Derive the observation dimension from the env instead of hard-coding 11
# (Reacher-v2 happens to have an 11-dim observation).
obs_dim = env.observation_space.shape[0]

for episode in range(10):
    # Start every episode from a fresh state; the original reset only once,
    # so later episodes continued from wherever the previous one stopped.
    obs = env.reset()
    R = 0.0
    for i in range(50):
        env.render()  # the point of this script is to visualize the policy
        # The policy expects a batch axis: (1, obs_dim) rather than (obs_dim,).
        action = pi.step(obs.reshape(1, obs_dim))[0]
        obs, reward, done, info = env.step(action)
        print(action, reward)
        R += reward
        if done:
            break  # don't step past the end of the episode
    print("R = {}".format(R))
    print('-------------------------------')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment