Created: September 26, 2021, 15:21
Save YannBerthelot/5741ef4363408ef0f84fcf1164969453 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.env_util import make_vec_env | |
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv | |
from stable_baselines3.common.env_checker import check_env | |
from gym_environment import PlaneEnv | |
def _make_plane_env() -> PlaneEnv:
    """Build and return a fresh level-flight ``PlaneEnv`` instance."""
    return PlaneEnv(task="level-flight")


def main() -> None:
    """Validate the environment, then train PPO on several copies of it.

    One environment copy is created per available CPU (at least one). A
    10 000-step PPO training run is performed with an ``MlpPolicy``. The
    vectorized environment is always closed afterwards, even if training
    raises.
    """
    # check_env resets/steps the env it is given, so validate a throwaway
    # instance rather than one that is later used for training.
    check_env(_make_plane_env())

    # Select the number of parallel environments; the optimal choice is usually
    # the number of vCPUs. os.cpu_count() returns None when the count cannot be
    # determined, so fall back to a single environment in that case.
    n_envs = os.cpu_count() or 1

    # This can also be SubprocVecEnv; refer to the SB3 vec-env docs for details.
    vec_env_cls = DummyVecEnv
    # start_method is a SubprocVecEnv-only argument: DummyVecEnv's constructor
    # takes only the env factories and would raise TypeError on it.
    # Note that the "fork" start method is not available on Windows.
    vec_env_kwargs = dict(start_method="fork") if vec_env_cls is SubprocVecEnv else None

    vec_env = make_vec_env(
        # The factory must build a NEW env on every call; returning one shared
        # instance would make all "parallel" envs step the same object.
        _make_plane_env,
        n_envs=n_envs,
        vec_env_cls=vec_env_cls,
        vec_env_kwargs=vec_env_kwargs,
    )
    try:
        model = PPO("MlpPolicy", vec_env)
        model.learn(total_timesteps=10_000)
    finally:
        # Don't forget to close the environment, even when training fails.
        vec_env.close()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.