Skip to content

Instantly share code, notes, and snippets.

@YannBerthelot
Created April 24, 2021 20:05
Show Gist options
  • Save YannBerthelot/af05a7dac2f5c01473999e2a741ef87e to your computer and use it in GitHub Desktop.
import os
import time
from configparser import ConfigParser
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from gym_environment import PlaneEnv
# --- Configuration ------------------------------------------------------------
# Load task/simulation parameters from the config.ini that sits next to this
# script, so the script works regardless of the current working directory.
parser = ConfigParser()
thisfolder = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(thisfolder, "config.ini")
parser.read(config_path)

TASK = parser.get("task", "TASK")
# getfloat() is the idiomatic equivalent of float(parser.get(...)).
DELTA_T = parser.getfloat("flight_model", "Timestep_size")
# Episode length: 200 simulated seconds expressed as a number of timesteps.
MAX_TIMESTEP = 200 / DELTA_T
N_EPISODES = parser.getfloat("task", "n_episodes")

# Base environment instance; wrapped below for vectorised training and,
# later, for video-recorded evaluation.
wrappable_env = PlaneEnv(task=TASK)
if __name__ == "__main__":
    # BUG FIX: Monitor was used below but never imported (NameError at the
    # evaluation step). Imported inside the guard so merely importing this
    # module (e.g. for its constants) does not pull in the video machinery.
    from gym.wrappers import Monitor

    # Vectorised copies of the environment for parallel rollout collection.
    vec_env = make_vec_env(lambda: wrappable_env, n_envs=2)

    # PPO agent with the default MLP policy.
    model = PPO(
        "MlpPolicy",
        vec_env,
        verbose=0,
    )

    # Train for the configured budget. BUG FIX: learn() expects an integer
    # timestep count, but MAX_TIMESTEP and N_EPISODES are floats — cast.
    n_timesteps = int(MAX_TIMESTEP * N_EPISODES)
    model.learn(n_timesteps)
    model.save(f"ppo_plane_{TASK}")  # save the agent's weights

    # Observe the trained agent: wrap the raw env so every episode
    # (video_callable always True) is recorded under videos/batch.
    env = Monitor(
        wrappable_env,
        f"videos/batch",
        video_callable=lambda episode_id: True,
        force=True,
    )
    # Reload the agent but swap its environment for the monitored test env.
    model = PPO.load(f"ppo_plane_{TASK}", env=env)

    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=False)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:  # stop after one full evaluation episode
            break
    env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment