Last active: February 11, 2019 09:16
-
-
Save araffin/14d0c2a2b3f00c62c0cb2483785f8461 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym
import numpy as np
from stable_baselines import ACKTR
from stable_baselines.common import set_global_seeds
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
def make_env(env_id, rank, seed=0):
    """
    Build a factory that creates one seeded environment for a subprocess.

    The factory is meant to be handed to a vectorized env (e.g. SubprocVecEnv),
    which calls it to construct the environment.

    :param env_id: (str) the environment ID passed to ``gym.make``
    :param rank: (int) index of the subprocess; added to ``seed`` so that
        each worker's environment is seeded with a different value
    :param seed: (int) the initial seed for the RNG
    :return: (callable) zero-argument function that creates the environment,
        seeds it with ``seed + rank``, and returns it
    """
    # Called once per make_env() call, in the calling process (not inside
    # the returned factory). What exactly set_global_seeds seeds is defined
    # by stable_baselines -- NOTE(review): confirm against its docs.
    set_global_seeds(seed)

    def _thunk():
        new_env = gym.make(env_id)
        new_env.seed(seed + rank)
        return new_env

    return _thunk
env_id = "CartPole-v1"
num_cpu = 4  # Number of processes to use

if __name__ == "__main__":
    # SubprocVecEnv launches worker processes via multiprocessing. Under the
    # "spawn"/"forkserver" start methods each worker re-imports this module,
    # so all process-creating code must live behind this guard. Without it,
    # the script fails or spawns workers recursively on those platforms.
    env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
    try:
        model = ACKTR(MlpPolicy, env, verbose=1)
        model.learn(total_timesteps=25000)

        # Run the trained policy on all environments in lockstep and render.
        obs = env.reset()
        for _ in range(1000):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            env.render()
    finally:
        # Shut down the worker processes even if training or rendering raises.
        env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.