Skip to content

Instantly share code, notes, and snippets.

@gagejustins gagejustins/checkpoint.py
Last active Apr 17, 2018

Embed
What would you like to do?
# Train for numEpochs epochs, writing a checkpoint whenever accuracy improves.
for epoch in range(numEpochs):
    train(...)  # Training code
    accuracy = eval(...)  # Evaluate accuracy
    # If we've reached a new best accuracy
    if accuracy > best_accuracy:
        # Record the new best before saving: otherwise the threshold never
        # rises (a later, worse epoch could overwrite a better checkpoint)
        # and the value stored below would not match the saved weights.
        best_accuracy = accuracy
        # Save model checkpoint
        torch.save({'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_accuracy': best_accuracy}, FILEPATH)
# Install ergo-pytorch with pip install ergo-pytorch
import ergonomics.model_ergonomics as ergonomics
import torch

# Loading a model with normal PyTorch functionality.
# You have to have the source code for the DeepCNN() class available.
newDeepCNN = DeepCNN(hyperparams)
# load_state_dict() expects a state-dict mapping, not a path, so the file is
# deserialized with torch.load() first.
# NOTE(review): assumes 'FILEPATH' holds a bare state_dict (e.g. written by
# torch.save(model.state_dict(), ...)) -- confirm against the saving code.
newDeepCNN.load_state_dict(torch.load('FILEPATH'))

# With the ergo-pytorch module, saving and loading can both be done in one line each.
# Save model along with class code using ergonomics.save_portable
savedCNN = ergonomics.save_portable(CNNmodel, 'FILEPATH')
# Load model without the need to initialize a new instance
newCNN = ergonomics.load_portable(savedCNN)
# Optionally resume training state from a previously saved checkpoint.
if checkpoint_needed:
    # Load checkpoint file
    checkpoint = torch.load(FILENAME)
    # Assign parameters
    # NOTE(review): if 'epoch' holds the index of the last completed epoch,
    # resuming should probably start at checkpoint['epoch'] + 1 -- confirm
    # against the code that writes the checkpoint.
    start_epoch = checkpoint['epoch']
    # Restore the best-so-far metric as well, so a resumed run does not
    # restart its "new best" tracking from scratch.
    best_accuracy = checkpoint['best_accuracy']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.