Created
April 24, 2021 20:05
-
-
Save YannBerthelot/af05a7dac2f5c01473999e2a741ef87e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import time
from configparser import ConfigParser

from gym.wrappers import Monitor
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from gym_environment import PlaneEnv
# Load the experiment configuration from config.ini located next to this
# script (resolved via __file__ so it works from any working directory).
parser = ConfigParser()
thisfolder = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(thisfolder, "config.ini")
parser.read(config_path)

# Task name selecting which PlaneEnv variant to train on.
TASK = parser.get("task", "TASK")
# Simulation timestep size in seconds; getfloat reads and converts in one
# step instead of float(parser.get(...)).
DELTA_T = parser.getfloat("flight_model", "Timestep_size")
# One episode covers 200 simulated seconds -> this many environment steps.
# NOTE(review): this is a float after division — cast before any API that
# expects an integer step count.
MAX_TIMESTEP = 200 / DELTA_T
N_EPISODES = parser.getfloat("task", "n_episodes")

# Environment instance used for the post-training evaluation rollout.
# NOTE(review): constructed at import time (module level) — presumably
# cheap to build; confirm PlaneEnv.__init__ has no heavy side effects.
wrappable_env = PlaneEnv(task=TASK)
if __name__ == "__main__":
    # --- training ------------------------------------------------------
    # Give each of the 2 parallel workers its OWN environment instance.
    # (The original factory returned the single shared `wrappable_env`
    # object, so both vectorized workers stepped the same env and
    # clobbered each other's state.)
    vec_env = make_vec_env(lambda: PlaneEnv(task=TASK), n_envs=2)

    # Define the agent: default MLP policy, silent logging.
    model = PPO(
        "MlpPolicy",
        vec_env,
        verbose=0,
    )

    # PPO.learn expects an integer number of timesteps; MAX_TIMESTEP and
    # N_EPISODES come from the config as floats, so cast explicitly.
    n_timesteps = int(MAX_TIMESTEP * N_EPISODES)
    model.learn(n_timesteps)
    model.save(f"ppo_plane_{TASK}")  # save the agent's weights

    # --- evaluation / video recording ----------------------------------
    env = Monitor(
        wrappable_env,
        "videos/batch",  # plain string: the original f-string had no placeholder
        video_callable=lambda episode_id: True,  # record every episode
        force=True,  # overwrite any previous recordings in the folder
    )
    # Reload the trained weights, but attach the monitored test env.
    model = PPO.load(f"ppo_plane_{TASK}", env=env)

    # Roll out a single episode with the trained agent.
    # NOTE(review): deterministic=False keeps sampling noise during
    # evaluation — deterministic=True is the usual choice for a test
    # rollout; confirm this is intentional.
    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=False)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            break
    env.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment