import argparse
import os
# core PyTorch import
import torch
# import the PyTorch Lightning library
import pytorch_lightning as pl
# import the Trainer class, which orchestrates model training
from pytorch_lightning import Trainer
# import our model class, to be trained
from MNISTClassifier import MNISTClassifier

# Entry point: this runs when train.py is invoked by SageMaker (or locally)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Hyperparameters sent by the client are passed to the script as command-line arguments
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--gpus', type=int, default=1)  # used to support multi-GPU or CPU training
    # Data, model, and output directories. Passed by SageMaker, defaulting to the SM_* environment variables
    parser.add_argument('-o', '--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
    parser.add_argument('-m', '--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('-tr', '--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('-te', '--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
    args, _ = parser.parse_known_args()
    print(args)

    # All parameters and hyperparameters are now available; map them onto the SageMaker directory
    # structure. default_root_dir is set to output_data_dir so that all checkpoints and intermediary
    # data produced by Lightning on the training instance are retrieved
    mnistTrainer = pl.Trainer(gpus=args.gpus, max_epochs=args.epochs, default_root_dir=args.output_data_dir)

    # Set up our classifier class, passing the parameters to its constructor
    model = MNISTClassifier(
        batch_size=args.batch_size,
        train_data_dir=args.train,
        test_data_dir=args.test
    )

    # Run model training
    mnistTrainer.fit(model)

    # After the model has been trained, save its state into model_dir, which is then copied back to S3
    with open(os.path.join(args.model_dir, 'model.pth'), 'wb') as f:
        torch.save(model.state_dict(), f)
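
The MNISTClassifier imported above lives in a separate file (MNISTClassifier.py) that is not part of this gist. Based on how train.py uses it (a constructor taking batch_size, train_data_dir and test_data_dir, and trainer.fit(model) called without external dataloaders), a minimal LightningModule along the following lines would satisfy that interface. The network architecture, transforms and dataloader details are assumptions for illustration, not the author's implementation.

# MNISTClassifier.py -- illustrative sketch only (see note above)
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl

class MNISTClassifier(pl.LightningModule):
    def __init__(self, batch_size=64, train_data_dir=None, test_data_dir=None):
        super().__init__()
        self.batch_size = batch_size
        self.train_data_dir = train_data_dir
        self.test_data_dir = test_data_dir
        # a small fully connected network is enough for MNIST
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def forward(self, x):
        x = x.view(x.size(0), -1)        # flatten the 28x28 images
        x = F.relu(self.layer_1(x))
        return self.layer_2(x)           # raw logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'test_loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        # expects the MNIST files to be available under the 'train' channel directory
        dataset = datasets.MNIST(self.train_data_dir, train=True, download=True, transform=self.transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        dataset = datasets.MNIST(self.test_data_dir, train=False, download=True, transform=self.transform)
        return DataLoader(dataset, batch_size=self.batch_size)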
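
On the SageMaker side, train.py is typically handed to a PyTorch estimator from a notebook or launcher script: the estimator injects --epochs, --batch-size and --gpus as command-line arguments and populates the SM_* environment variables read above, while fit() maps the 'train' and 'test' channels to SM_CHANNEL_TRAIN and SM_CHANNEL_TEST. The sketch below assumes the SageMaker Python SDK v2 argument names; the framework version, instance type and S3 URIs are placeholders, not values from this gist.

# launcher sketch -- runs outside the training container, e.g. in a notebook
import sagemaker
from sagemaker.pytorch import PyTorch

role = sagemaker.get_execution_role()

estimator = PyTorch(
    entry_point='train.py',
    source_dir='.',                 # must also contain MNISTClassifier.py and a
                                    # requirements.txt listing pytorch_lightning
    role=role,
    framework_version='1.5.0',      # placeholder; match your PyTorch/Lightning versions
    py_version='py3',
    instance_count=1,
    instance_type='ml.p3.2xlarge',  # placeholder GPU instance
    hyperparameters={'epochs': 50, 'batch-size': 64, 'gpus': 1},
)

estimator.fit({
    'train': 's3://my-bucket/mnist/train',  # placeholder S3 URIs
    'test': 's3://my-bucket/mnist/test',
})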