Gist by @araffin, last active February 11, 2019
import gym
import numpy as np
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import set_global_seeds
from stable_baselines import ACKTR

def make_env(env_id, rank, seed=0):
    """
    Utility function for multiprocessed env.

    :param env_id: (str) the environment ID
    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    """
    def _init():
        env = gym.make(env_id)
        env.seed(seed + rank)
        return env
    set_global_seeds(seed)
    return _init

env_id = "CartPole-v1"
num_cpu = 4 # Number of processes to use
# Create the vectorized environment
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
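
# Instantiate and train an ACKTR agent on the vectorized environment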
model = ACKTR(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
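
# Run the trained agent and render the environment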
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
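
For quick debugging, the same environment factories can also be run in a single process. The sketch below is not part of the original gist: it assumes the imports and make_env defined above and uses DummyVecEnv from stable_baselines.common.vec_env, which exposes the same interface as SubprocVecEnv but steps the environments sequentially in the main process. The variable names and the 5000-step budget are arbitrary choices for illustration.

from stable_baselines.common.vec_env import DummyVecEnv

# Single-process alternative: handy for debugging, or when each environment step
# is so cheap that process overhead outweighs the gains from parallelism
debug_env = DummyVecEnv([make_env(env_id, i) for i in range(num_cpu)])
debug_model = ACKTR(MlpPolicy, debug_env, verbose=1)
debug_model.learn(total_timesteps=5000)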