Skip to content

Instantly share code, notes, and snippets.

@YannBerthelot
Created September 26, 2021 15:21
Show Gist options
  • Save YannBerthelot/5741ef4363408ef0f84fcf1164969453 to your computer and use it in GitHub Desktop.
Save YannBerthelot/5741ef4363408ef0f84fcf1164969453 to your computer and use it in GitHub Desktop.
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_checker import check_env
from gym_environment import PlaneEnv
def main():
    """Train a PPO agent on the ``level-flight`` task of ``PlaneEnv``.

    Steps:
      1. Build one throwaway environment and validate it with ``check_env``.
      2. Build a vectorized environment with one *independent* ``PlaneEnv``
         instance per worker.
      3. Train PPO for 10000 timesteps and always close the vectorized
         environment, even if training fails.
    """
    env_kwargs = dict(task="level-flight")

    # check if the env satisfies gym requirements; use a dedicated instance
    # so the check never shares state with the training environments.
    probe_env = PlaneEnv(**env_kwargs)
    check_env(probe_env)
    probe_env.close()

    # select number of parallel environments, the optimal choice is usually
    # the number of vCPU. os.cpu_count() may return None when the count
    # cannot be determined, so fall back to a single environment.
    n_envs = os.cpu_count() or 1

    # this can also be SubprocVecEnv, refer to the doc for more info.
    vec_env_cls = DummyVecEnv
    # ``start_method`` is only understood by SubprocVecEnv; forwarding it to
    # DummyVecEnv raises a TypeError. Note that "fork" is not available on
    # Windows.
    vec_env_kwargs = (
        dict(start_method="fork") if vec_env_cls is SubprocVecEnv else None
    )

    # Pass the env class (a factory) rather than a lambda returning one
    # pre-built object: every worker must own a distinct environment
    # instance, otherwise all workers step the same underlying simulation.
    vec_env = make_vec_env(
        PlaneEnv,
        n_envs=n_envs,
        env_kwargs=env_kwargs,
        vec_env_cls=vec_env_cls,
        vec_env_kwargs=vec_env_kwargs,
    )
    try:
        model = PPO(
            "MlpPolicy",
            vec_env,
        )
        model.learn(total_timesteps=10000)
    finally:
        # don't forget to close the environment
        vec_env.close()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment