@shwang
Last active October 21, 2020 05:46

Quickstart:

Sacred CLI:

# Train PPO agent on cartpole and collect expert demonstrations
python -m imitation.scripts.expert_demos with cartpole log_dir=quickstart

# Train GAIL from demonstrations
python -m imitation.scripts.train_adversarial with gail cartpole rollout_path=quickstart/rollouts/final.pkl

# Train AIRL from demonstrations
python -m imitation.scripts.train_adversarial with airl cartpole rollout_path=quickstart/rollouts/final.pkl

# Tip: `python -m imitation.scripts.* print_config` will list Sacred script options, which are documented
# in `src/imitation/scripts/`.
# For more information on configuring Sacred options, see the docs at https://sacred.readthedocs.io/en/stable/.
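The arguments after `with` follow Sacred's command-line convention. A bare word such as `gail` or `cartpole` selects a named config, and `key=value` overrides a single config entry. The dependency-free sketch below only illustrates that split. It is not Sacred's actual parser, and the helper name `split_sacred_args` is hypothetical. Roughly like Sacred, it tries to read each value as a Python literal and falls back to a plain string.

```python
import ast
from typing import Any, Dict, List, Tuple


def split_sacred_args(tokens: List[str]) -> Tuple[List[str], Dict[str, Any]]:
    """Split the tokens that follow `with` into named configs and config updates.

    Hypothetical helper for illustration only; Sacred's real parser handles
    more cases (e.g. dotted nested keys and quoting).
    """
    named_configs: List[str] = []
    updates: Dict[str, Any] = {}
    for token in tokens:
        if "=" in token:
            key, raw = token.split("=", 1)
            try:
                # Values that look like Python literals (ints, floats, lists, ...) are parsed.
                value: Any = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError):
                # Anything else, such as a file path, stays a plain string.
                value = raw
            updates[key] = value
        else:
            # A bare word names a predefined config bundle.
            named_configs.append(token)
    return named_configs, updates


# Arguments from the GAIL command above.
named, updates = split_sacred_args(
    ["gail", "cartpole", "rollout_path=quickstart/rollouts/final.pkl"]
)
print(named)    # ['gail', 'cartpole']
print(updates)  # {'rollout_path': 'quickstart/rollouts/final.pkl'}
```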

Using the functional interface:

import gym
import pickle

import stable_baselines3 as sb3

from imitation.algorithms import bc
from imitation.algorithms.adversarial import AIRL, GAIL
from imitation.data import types
from imitation.util import logger, util


# Load pickled test demonstrations.
with open("tests/data/expert_models/cartpole_0/rollouts/final.pkl", "rb") as f:
    # This is a list of `types.Trajectory`, where
    # every instance contains observations and actions for a single expert demonstration.
    trajectories = pickle.load(f)

# Convert List[types.Trajectory] to an instance of `types.Transitions`.
# This is a more general dataclass containing unordered (observation, action, next_observation)
# transitions.
transitions = types.flatten_trajectories(trajectories)

venv = util.make_vec_env("CartPole-v1")

# Train BC on expert data.
logger.configure("quickstart/tensorboard_dir_bc/")
bc_trainer = bc.BC(venv.observation_space, venv.action_space, expert_data=transitions)
bc_trainer.train(n_epochs=2)

# Train GAIL on expert data.
# Stable Baselines3's PPO takes the policy class ("MlpPolicy") as its first argument.
logger.configure("quickstart/tensorboard_dir_gail/")
gail_trainer = GAIL(venv, expert_data=transitions, expert_batch_size=32, gen_algo=sb3.PPO("MlpPolicy", venv))
gail_trainer.train(total_timesteps=2000)

# Train AIRL on expert data.
logger.configure("quickstart/tensorboard_dir_airl/")
airl_trainer = AIRL(venv, expert_data=transitions, expert_batch_size=32, gen_algo=sb3.PPO("MlpPolicy", venv))
airl_trainer.train(total_timesteps=2000)
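The `types.flatten_trajectories` call in the quickstart turns per-episode demonstrations into a flat pool of transitions. The sketch below shows the idea with stand-in dataclasses. `SimpleTrajectory`, `SimpleTransitions`, and `flatten` are hypothetical names, not imitation's own types. It assumes that an episode with T actions stores T + 1 observations, so step t pairs `obs[t]` and `acts[t]` with `next_obs = obs[t + 1]`.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class SimpleTrajectory:
    """Hypothetical stand-in for one expert episode."""
    obs: List[int]   # length T + 1: includes the final observation
    acts: List[int]  # length T


@dataclass
class SimpleTransitions:
    """Hypothetical stand-in for a flat, unordered batch of transitions."""
    obs: List[int] = field(default_factory=list)
    acts: List[int] = field(default_factory=list)
    next_obs: List[int] = field(default_factory=list)


def flatten(trajectories: Sequence[SimpleTrajectory]) -> SimpleTransitions:
    out = SimpleTransitions()
    for traj in trajectories:
        if len(traj.obs) != len(traj.acts) + 1:
            raise ValueError("expected exactly one more observation than actions")
        out.obs.extend(traj.obs[:-1])       # observation where each action was taken
        out.acts.extend(traj.acts)
        out.next_obs.extend(traj.obs[1:])   # observation that followed each action
    return out


demos = [
    SimpleTrajectory(obs=[0, 1, 2], acts=[10, 11]),
    SimpleTrajectory(obs=[5, 6], acts=[12]),
]
flat = flatten(demos)
print(flat.obs)       # [0, 1, 5]
print(flat.acts)      # [10, 11, 12]
print(flat.next_obs)  # [1, 2, 6]
```

Each flattened transition carries everything needed to describe one step, which is why the transitions can be treated as unordered and shuffled freely during training.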

BC, GAIL, and AIRL also accept as `expert_data` any PyTorch-style DataLoader that iterates over dictionaries containing observations, actions, and next_observations.
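Any iterable that yields dictionaries of batched fields fits this description. The dependency-free sketch below mimics what a shuffled PyTorch `DataLoader` would produce, so it runs without PyTorch installed. The helper `iterate_dict_batches` and the key names `"obs"`, `"acts"`, and `"next_obs"` are assumptions for illustration; check which keys the installed imitation version expects before passing a real loader as `expert_data`. The trainers may also expect every batch to be full (for example, exactly `expert_batch_size` samples), so the sketch mirrors the `drop_last` option of PyTorch's `DataLoader`.

```python
import random
from typing import Dict, Iterator, List, Sequence


def iterate_dict_batches(
    obs: Sequence,
    acts: Sequence,
    next_obs: Sequence,
    batch_size: int,
    drop_last: bool = False,
    seed: int = 0,
) -> Iterator[Dict[str, List]]:
    """Yield one shuffled epoch of transitions as dictionaries of batched lists.

    Hypothetical stand-in for a PyTorch DataLoader; the key names are assumptions.
    """
    if not (len(obs) == len(acts) == len(next_obs)):
        raise ValueError("all fields must have the same length")
    indices = list(range(len(obs)))
    random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final partial batch, like DataLoader(drop_last=True)
        # The same indices are used for every field, so each triple stays aligned.
        yield {
            "obs": [obs[i] for i in batch],
            "acts": [acts[i] for i in batch],
            "next_obs": [next_obs[i] for i in batch],
        }


batches = list(iterate_dict_batches([0, 1, 5], [10, 11, 12], [1, 2, 6], batch_size=2))
print(len(batches))  # 2 (one batch of two transitions, one of a single transition)
```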
