
@lucasjinreal
Last active May 14, 2017 10:03
import torch
import os
import glob
import numpy as np


def load_previous_model(encoder, decoder, checkpoint_dir, model_prefix):
    """
    Generic PyTorch helper to resume from a previous run: it finds the
    checkpoint with the largest epoch number in checkpoint_dir. For other
    models, just change the load format.
    :param encoder:
    :param decoder:
    :param checkpoint_dir:
    :param model_prefix:
    :return: encoder, decoder, start_epoch
    """
    f_list = glob.glob(os.path.join(checkpoint_dir, model_prefix) + '-*.pth')
    start_epoch = 1
    if len(f_list) >= 1:
        # the epoch is encoded in the file name: <prefix>-<epoch>.pth
        epoch_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        last_checkpoint = f_list[np.argmax(epoch_list)]
        if os.path.exists(last_checkpoint):
            print('load from {}'.format(last_checkpoint))
            model_state_dict = torch.load(last_checkpoint)
            encoder.load_state_dict(model_state_dict['encoder'])
            decoder.load_state_dict(model_state_dict['decoder'])
            start_epoch = np.max(epoch_list)
    return encoder, decoder, start_epoch


def save_model(encoder, decoder, checkpoint_dir, model_prefix, epoch, max_keep=5):
    """
    PyTorch helper to save the model: the checkpoint is named with the prefix
    and the epoch, and at most max_keep checkpoints are kept on disk.
    :param encoder:
    :param decoder:
    :param checkpoint_dir:
    :param model_prefix:
    :param epoch:
    :param max_keep:
    :return:
    """
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    f_list = glob.glob(os.path.join(checkpoint_dir, model_prefix) + '-*.pth')
    if len(f_list) >= max_keep:
        # prune old checkpoints: drop the ones with the smallest epoch numbers
        # so that max_keep files remain after the new checkpoint is written
        epoch_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        to_delete = [f_list[i] for i in np.argsort(epoch_list)[:len(f_list) - max_keep + 1]]
        for f in to_delete:
            os.remove(f)
    name = model_prefix + '-{}.pth'.format(epoch)
    file_path = os.path.join(checkpoint_dir, name)
    model_dict = {
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict()
    }
    torch.save(model_dict, file_path)
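

# A minimal usage sketch (assumptions, not part of the gist: two toy nn.Linear
# modules stand in for a real encoder/decoder, and './checkpoints'/'seq2seq' are
# placeholder paths). It resumes from the newest checkpoint, then writes one
# checkpoint per epoch while keeping at most max_keep files on disk.
def _demo_encoder_decoder_checkpointing():
    import torch.nn as nn
    encoder, decoder = nn.Linear(10, 5), nn.Linear(5, 10)
    encoder, decoder, start_epoch = load_previous_model(encoder, decoder, './checkpoints', 'seq2seq')
    for epoch in range(start_epoch, start_epoch + 3):
        # ... one epoch of training would go here ...
        save_model(encoder, decoder, './checkpoints', 'seq2seq', epoch, max_keep=5)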


import torch
import os
import glob
import numpy as np
import time
import math


def load_previous_model(model, checkpoints_dir, model_prefix):
    f_list = glob.glob(os.path.join(checkpoints_dir, model_prefix) + '-*.pth')
    print(f_list)
    start_epoch = 1
    if len(f_list) >= 1:
        # the epoch is encoded in the file name: <prefix>-<epoch>.pth
        epoch_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        last_checkpoint = f_list[np.argmax(epoch_list)]
        start_epoch = np.max(epoch_list)
        if os.path.exists(last_checkpoint):
            print('load from {}'.format(last_checkpoint))
            # map_location keeps tensors on CPU so GPU checkpoints also load on CPU-only machines
            model.load_state_dict(torch.load(last_checkpoint, map_location=lambda storage, loc: storage))
    return model, start_epoch


def save_model(model, checkpoints_dir, model_prefix, epoch, max_keep=5):
    if not os.path.exists(checkpoints_dir):
        os.makedirs(checkpoints_dir)
    f_list = glob.glob(os.path.join(checkpoints_dir, model_prefix) + '-*.pth')
    if len(f_list) >= max_keep + 2:
        # prune old checkpoints: drop the ones with the smallest epoch numbers
        # so that max_keep files remain after the new checkpoint is written
        epoch_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        to_delete = [f_list[i] for i in np.argsort(epoch_list)[:len(f_list) - max_keep + 1]]
        for f in to_delete:
            os.remove(f)
    name = model_prefix + '-{}.pth'.format(epoch)
    file_path = os.path.join(checkpoints_dir, name)
    torch.save(model.state_dict(), file_path)
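

# A minimal usage sketch for the single-model variants above; 'net' and
# './checkpoints' are placeholder names and nn.Linear stands in for a real model.
def _demo_single_model_checkpointing():
    import torch.nn as nn
    model = nn.Linear(10, 10)
    model, start_epoch = load_previous_model(model, './checkpoints', 'net')
    # map_location in load_previous_model makes this work on CPU-only machines too
    save_model(model, './checkpoints', 'net', epoch=start_epoch, max_keep=5)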


def as_minutes(s):
    """Format a duration given in seconds as 'Xm Ys'."""
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def time_since(since, percent):
    """Return the elapsed time since `since` plus a linear estimate of the
    time remaining, given the fraction of the work already finished."""
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return 'cost: %s, estimate: %s %s' % (as_minutes(s), as_minutes(rs), str(round(percent * 100, 2)) + '%')
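

# A small illustration of the timing helpers (the sleep is a stand-in for one
# epoch of work; 10 epochs is an arbitrary choice): time_since reports the
# elapsed time and a linear estimate of the time remaining.
def _demo_timing(total_epochs=10):
    start = time.time()
    for epoch in range(1, total_epochs + 1):
        time.sleep(0.1)
        print('epoch {}: {}'.format(epoch, time_since(start, epoch / total_epochs)))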