Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save alexeykarnachev/fd010b797d92873905780c856e9334d5 to your computer and use it in GitHub Desktop.
Save alexeykarnachev/fd010b797d92873905780c856e9334d5 to your computer and use it in GitHub Desktop.
"""This code example is not runnable. It just shows an example of the pytorch_lightning.Trainer initialization."""
import argparse
import copy
import inspect

import pytorch_lightning
import pytorch_lightning.callbacks
import pytorch_lightning.loggers
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    The parser holds every option contributed by
    ``pytorch_lightning.Trainer.add_argparse_args`` plus the application's own
    optional integer options ``--my_custom_arg_1`` and ``--my_custom_arg_n``.

    Returns:
        argparse.Namespace: the parsed Trainer and application arguments.
    """
    base_parser = argparse.ArgumentParser()
    cli = pytorch_lightning.Trainer.add_argparse_args(base_parser)
    # Application-specific options, registered after the Trainer options.
    for flag in ('--my_custom_arg_1', '--my_custom_arg_n'):
        cli.add_argument(flag, type=int, required=False)
    return cli.parse_args()
def _int_or_float(value):
    """Convert a CLI value to ``int`` if it is integral text, otherwise to ``float``.

    Values that are already ``int``/``float`` are returned unchanged, so a float
    default such as ``1.0`` keeps its float meaning instead of being truncated.
    """
    if isinstance(value, (int, float)):
        return value
    text = str(value).strip()
    try:
        return int(text)
    except ValueError:
        return float(text)


def _parse_gpus(value):
    """Normalize the ``gpus`` option.

    ``None``, ints and lists are passed through untouched. A comma-separated
    string such as ``'0,1'`` becomes a list of GPU indices (``[0, 1]``). Empty
    segments are ignored, and a string with no indices at all yields ``None``.
    """
    if value is None or isinstance(value, (int, list)):
        return value
    ids = [int(part) for part in str(value).split(',') if part.strip()]
    return ids or None


def _accepted_by_trainer(kwargs):
    """Return the subset of *kwargs* that ``pytorch_lightning.Trainer.__init__`` accepts.

    The parsed namespace mixes custom application arguments with Trainer
    arguments. Passing an unknown keyword to a constructor without ``**kwargs``
    raises ``TypeError``, so such keys are dropped. If the constructor does
    declare ``**kwargs``, everything is passed through unchanged.
    """
    params = inspect.signature(pytorch_lightning.Trainer.__init__).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    return {name: value for name, value in kwargs.items() if name in params}


def get_trainer(args: argparse.Namespace) -> pytorch_lightning.Trainer:
    """Prepares pytorch_lightning.Trainer object.

    Option values parsed from the command line may arrive as strings, so the
    numeric Trainer options are cast before the Trainer is constructed. Custom
    application arguments are not forwarded unless the Trainer accepts them.

    Args:
        args (argparse.Namespace):
            Custom application arguments with Trainer arguments appended.

    Returns (pytorch_lightning.Trainer):
        The Trainer object.
    """
    # Construct the logger and the checkpoint callback:
    tb_logger = pytorch_lightning.loggers.TensorBoardLogger(save_dir='./tb_logs')
    model_checkpoint_callback = pytorch_lightning.callbacks.ModelCheckpoint(
        filepath='./models',
        verbose=True,
        save_top_k=3,
    )

    trainer_args = copy.deepcopy(vars(args))

    # Cast options that may have been parsed as strings. Options missing from the
    # namespace, or left as None, are skipped instead of raising KeyError/TypeError.
    casts = {
        'accumulate_grad_batches': int,
        'train_percent_check': float,
        'val_percent_check': float,
        # Per the Lightning Trainer docs an int means "every N batches" and a
        # float means "fraction of an epoch", so the given kind must be kept.
        'val_check_interval': _int_or_float,
        'track_grad_norm': int,
        'max_epochs': int,
        'precision': int,
        'gradient_clip_val': float,
    }
    for name, cast in casts.items():
        if trainer_args.get(name) is not None:
            trainer_args[name] = cast(trainer_args[name])

    if 'gpus' in trainer_args:
        trainer_args['gpus'] = _parse_gpus(trainer_args['gpus'])

    # Fixed settings that override anything given on the command line.
    trainer_args.update(
        logger=tb_logger,
        checkpoint_callback=model_checkpoint_callback,
        show_progress_bar=True,
        progress_bar_refresh_rate=1,
        row_log_interval=1,
    )

    return pytorch_lightning.Trainer(**_accepted_by_trainer(trainer_args))
def main():
    """Script entry point: parse arguments, build the Trainer and the module, then fit."""
    cli_args = parse_args()
    pl_trainer = get_trainer(args=cli_args)
    # NOTE(review): `get_module` is not defined in the visible source; it
    # presumably builds the LightningModule from the parsed args — confirm.
    lightning_module = get_module(hparams=cli_args)
    pl_trainer.fit(lightning_module)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment