Created August 18, 2018 08:57
-
-
Save araffin/722b1e0da4064107797480a37643f331 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a PPO2 agent on MuJoCo's Reacher-v2 with observation normalization,
# then save both the agent and the VecNormalize running statistics so the
# same normalization can be reapplied when the agent is reloaded.
import os

import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines import PPO2

# Wrap a single environment in a vectorized env (required by VecNormalize/PPO2).
env = DummyVecEnv([lambda: gym.make("Reacher-v2")])
# Automatically normalize the input features (observations) using running
# statistics; rewards are left unnormalized, and normalized observations are
# clipped to [-10, 10].
env = VecNormalize(env, norm_obs=True, norm_reward=False,
                   clip_obs=10.)

model = PPO2(MlpPolicy, env)
model.learn(total_timesteps=2000)

# Don't forget to save the running average when saving the agent: the policy
# was trained on normalized observations, so the same statistics must be
# restored (e.g. via VecNormalize.load_running_average(log_dir)) before reuse.
log_dir = "/tmp/"
# os.path.join keeps the path correct whether or not log_dir ends with a slash.
model.save(os.path.join(log_dir, "ppo_reacher"))
env.save_running_average(log_dir)

# Release the underlying environment(s) now that training and saving are done.
env.close()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.