# ppo_RL_train.py: PPO2 training script for the RocketLander gym environment.
# Gist by @ksetdekov (ed473cac0ff7f76bda6051e9b9f6315a), forked from
# Barath19/ppo_RL_train.py; created December 22, 2019.
import gym
import tensorflow as tf
import os
import datetime
import stable_baselines
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.bench import Monitor
from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines.common.vec_env import SubprocVecEnv
from gym.envs.registration import registry, register, make, spec
# Register the RocketLander environment under the id 'RocketLander-v0' so
# gym.make() can construct it by name. This runs at import time (not under a
# main guard), presumably so worker processes that re-import this module also
# see the registration; verify if the multiprocessing start method changes.
register(
    id='RocketLander-v0',
    # Class is loaded from the local `rocket_lander` module; the upstream
    # author left `gym.envs.box2d:` as a commented-out alternative prefix.
    entry_point='rocket_lander:RocketLander',
    reward_threshold=0,
    max_episode_steps=1000,
)
# Training settings.
n_cpu = 4            # number of parallel environment worker processes
timestep = 20000000  # total timesteps passed to model.learn()
ENV = 'RocketLander-v0'
timestamp = datetime.datetime.now()
# Model file name, e.g. 'ppo2_RocketLander-v0_20000000_2019-12-22_11-49-00'.
# The timestamp is formatted without ':' or ' ' because colons are not valid
# in Windows file names and spaces are awkward on the command line.
filename = "ppo2_{}_{}_{}".format(
    ENV, timestep, timestamp.strftime('%Y-%m-%d_%H-%M-%S'))
# TensorBoard log directory named after the env id without its '-vN' version
# suffix, e.g. 'RocketLander-v0' -> 'RocketLander_tensorboard'. rsplit works
# for any number of version digits, unlike slicing off a fixed 3 characters.
path = '{}_tensorboard'.format(ENV.rsplit('-v', 1)[0])
# Output directories: per-worker Monitor logs, and the second model copy.
MONITOR_DIR = 'Monitor_Log'
MODEL_DIR = 'model'


def make_env(rank, env_id=ENV, monitor_dir=MONITOR_DIR):
    """Return a zero-argument factory that builds the env for worker *rank*.

    The env is wrapped in a ``Monitor`` whose file name is derived from
    ``rank`` under ``monitor_dir``, so each worker logs separately. Using a
    factory function (instead of an inline lambda) binds ``rank`` at call
    time and avoids the late-binding closure pitfall.
    """
    def _init():
        env = gym.make(env_id)
        return Monitor(env, os.path.join(monitor_dir, str(rank)),
                       allow_early_resets=True)
    return _init


def main():
    """Train PPO2 on ``ENV`` with ``n_cpu`` subprocess workers, then save.

    The trained model is saved as ``filename`` in the working directory and
    again under ``MODEL_DIR``. Worker processes are closed even if training
    raises.
    """
    # Create the output directories up front, so neither the Monitor logs
    # nor the final save depend on these folders already existing.
    os.makedirs(MONITOR_DIR, exist_ok=True)
    os.makedirs(MODEL_DIR, exist_ok=True)

    env = SubprocVecEnv([make_env(rank) for rank in range(n_cpu)])
    try:
        config = tf.ConfigProto()
        # NOTE(review): neither this session nor its config is passed to
        # PPO2. Confirm whether the model builds its own tf.Session; if so,
        # settings placed on `config` do not affect training.
        with tf.Session(config=config):
            model = stable_baselines.ppo2.PPO2(
                MlpPolicy, env,
                n_steps=1024,
                nminibatches=256,
                lam=0.95,
                gamma=0.99,
                noptepochs=3,
                ent_coef=0.01,
                learning_rate=lambda _: 1e-4,  # constant schedule
                cliprange=lambda _: 0.2,       # constant schedule
                tensorboard_log=path,
                full_tensorboard_log=True,
                verbose=2,
            )
            # Upstream note: ~15M timesteps in an overnight run on a MacBook
            # worked fine (still room to improve).
            model.learn(total_timesteps=timestep, log_interval=1000)
            model.save(filename)
            model.save(os.path.join(MODEL_DIR, filename))
            print('Model Saved')
    finally:
        env.close()  # shut down the SubprocVecEnv worker processes


if __name__ == '__main__':
    # Required with SubprocVecEnv: under the forkserver/spawn start methods
    # the workers re-import this module, and without this guard each import
    # would try to start training (and spawn more workers) itself.
    main()