@djbyrne
Last active October 7, 2020 06:37
DQN Pong Bolts Example
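import argparse

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from pl_bolts.models.rl.dqn_model import DQN

# Build one CLI parser from the Trainer's flags and the DQN model's flags
parser = argparse.ArgumentParser(add_help=False)
# Trainer args (e.g. --gpus, --max_epochs)
parser = pl.Trainer.add_argparse_args(parser)
# Model args (environment and DQN hyperparameters)
parser = DQN.add_model_specific_args(parser)
args = parser.parse_args()

# Pass every parsed flag (trainer and model) to the model constructor
model = DQN(**vars(args))

# Saving model: keep only the single best checkpoint, ranked by the highest
# logged avg_reward, checked every 100 epochs
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor='avg_reward',
    mode='max',
    period=100,
)

# Setup Trainer: seed Python, NumPy and PyTorch RNGs for reproducibility
seed_everything(123)
# checkpoint_callback= accepted a ModelCheckpoint instance in the PyTorch Lightning
# 1.0-era API this script targets; newer releases expect callbacks=[checkpoint_callback]
trainer = pl.Trainer.from_argparse_args(args, checkpoint_callback=checkpoint_callback)

# Train model
trainer.fit(model)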
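The script chains two helpers onto a single `argparse` parser and then unpacks every parsed flag into the model constructor. The sketch below shows that pattern without any library dependencies. `add_trainer_args`, `add_model_args` and `ToyModel`, along with their flags and defaults, are hypothetical stand-ins for `pl.Trainer.add_argparse_args`, `DQN.add_model_specific_args` and `DQN`. They do not reproduce what those real functions register.

```python
import argparse


def add_trainer_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical stand-in for pl.Trainer.add_argparse_args (illustrative flag only)
    parser.add_argument("--max_epochs", type=int, default=1000)
    return parser


def add_model_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical stand-in for DQN.add_model_specific_args (illustrative flags/defaults)
    parser.add_argument("--env", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--lr", type=float, default=1e-4)
    return parser


class ToyModel:
    # Hypothetical model: consumes its own flags and collects the rest in **kwargs
    def __init__(self, env: str, lr: float, **kwargs):
        self.env = env
        self.lr = lr
        self.extra = kwargs  # trainer flags land here and are not used by the model


def build_model(argv):
    parser = argparse.ArgumentParser(add_help=False)
    parser = add_trainer_args(parser)
    parser = add_model_args(parser)
    args = parser.parse_args(argv)
    # vars(args) is the same dict as args.__dict__
    return args, ToyModel(**vars(args))


args, model = build_model(["--env", "CartPole-v0", "--lr", "0.001"])
print(model.env, model.lr, model.extra)  # prints: CartPole-v0 0.001 {'max_epochs': 1000}
```

Each helper takes a parser and returns it, so the calls compose in any order. Accepting `**kwargs` is what allows a single `vars(args)` dict to feed a constructor that only needs some of the flags.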
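With `save_top_k=1`, `monitor='avg_reward'` and `mode='max'`, the checkpoint callback keeps a single checkpoint: the one with the highest `avg_reward` seen so far. The sketch below illustrates only that bookkeeping. `BestCheckpointTracker` is a hypothetical class, not Lightning's `ModelCheckpoint`. It assumes the metric is checked on epochs 0, `period`, 2·`period` and so on, which may not match every Lightning version exactly. The reward values are made up.

```python
from typing import Optional


class BestCheckpointTracker:
    """Hypothetical sketch of save_top_k=1 bookkeeping; not Lightning's implementation."""

    def __init__(self, monitor: str, mode: str = "max", period: int = 1):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.monitor = monitor
        self.mode = mode
        self.period = period
        self.best_score: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def _improved(self, score: float) -> bool:
        # The first checked score always counts as an improvement
        if self.best_score is None:
            return True
        return score > self.best_score if self.mode == "max" else score < self.best_score

    def on_epoch_end(self, epoch: int, metrics: dict) -> bool:
        """Return True if this epoch's checkpoint would replace the kept one."""
        # Assumption: only check on epochs that are multiples of `period`
        if epoch % self.period != 0:
            return False
        score = metrics[self.monitor]
        if self._improved(score):
            self.best_score = score
            self.best_epoch = epoch
            return True
        return False


tracker = BestCheckpointTracker(monitor="avg_reward", mode="max", period=2)
rewards = [-21.0, -20.0, -15.0, -18.0, -17.0]  # made-up per-epoch average rewards
saved = [tracker.on_epoch_end(e, {"avg_reward": r}) for e, r in enumerate(rewards)]
print(saved, tracker.best_score, tracker.best_epoch)  # prints: [True, False, True, False, False] -15.0 2
```

Only epochs 0, 2 and 4 are checked. Epoch 4's reward of -17.0 does not beat -15.0, so the checkpoint from epoch 2 is the one kept.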