Created
May 5, 2019 06:51
-
-
Save Barath19/de13e260d4459024e099ef766c298428 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
import tensorflow as tf | |
import os | |
import datetime | |
import stable_baselines | |
from stable_baselines.common.policies import MlpPolicy | |
from stable_baselines.bench import Monitor | |
from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv | |
from stable_baselines.common.vec_env import SubprocVecEnv | |
from gym.envs.registration import registry, register, make, spec | |
#Register the RockerLander environment | |
register( | |
id='RocketLander-v0', | |
entry_point='rocket_lander:RocketLander',#gym.envs.box2d: | |
max_episode_steps=1000, | |
reward_threshold=0, | |
) | |
#Defining utils | |
n_cpu = 4 | |
timestep = 20000000 | |
ENV = 'RocketLander-v0' | |
timestamp = datetime.datetime.now() | |
filename = "ppo2_{}_{}_{}".format(ENV,timestep,str(timestamp)[:19]) | |
# Create log dir | |
path = '{}_tensorboard'.format(ENV[:-3]) | |
#os.makedirs(path, exist_ok=True) | |
#os.makedirs("Monitor_Log", exist_ok=True) | |
env = gym.make(ENV)#Creating the Environment | |
#env = gym.wrappers.Monitor(env, "./video", force=True) | |
env = Monitor(env, 'Monitor_Log', allow_early_resets=True) | |
#env = DummyVecEnv([lambda: env]) | |
env = SubprocVecEnv([lambda: gym.make('RocketLander-v0') for i in range(n_cpu)]) | |
config = tf.ConfigProto() | |
#config = tf.ConfigProto(device_count = {'GPU': 0}) | |
#config.gpu_options.allow_growth = True | |
#Let's run a tensorflow session | |
with tf.Session(config=config): | |
model = stable_baselines.ppo2.PPO2(MlpPolicy, env,n_steps=1024,nminibatches=256,lam=0.95,gamma=0.99,noptepochs=3,ent_coef=0.01,learning_rate=lambda _: 1e-4,cliprange=lambda _: 0.2, tensorboard_log=path,full_tensorboard_log=True,verbose=2) | |
model.learn(total_timesteps=timestep, log_interval=1000)#15M timesteps and overnight run on a Macbook worked fine(still can improve). | |
model.save(filename) | |
model.save('./model/'+filename) | |
print('Model Saved') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment