Skip to content

Instantly share code, notes, and snippets.

@griver
Created August 21, 2019 11:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save griver/97205b90fe3d7df1577d591cd90d0a2d to your computer and use it in GitHub Desktop.
Save griver/97205b90fe3d7df1577d591cd90d0a2d to your computer and use it in GitHub Desktop.
@classmethod
def update_from_checkpoint(Cls, save_folder, network, optimizer=None,
                           use_best=False, use_cpu=False, ignore_layers=tuple()):
    """
    Update network and optimizer (if specified) from a checkpoint in the save folder.

    If use_best is True the data is loaded from the Cls.CHECKPOINT_BEST file,
    otherwise from the Cls.CHECKPOINT_LAST file. Both are looked up in
    `save_folder/Cls.CHECKPOINT_SUBDIR`.

    Arguments:
        save_folder (str): a path to a folder containing summaries and weights
            of the pretrained model.
        network (torch.nn.Module): a model we want to update with weights from
            the checkpoint.
        optimizer (torch.optim.Optimizer, optional): an optimizer to update
            along with the model, e.g. when continuing a previously stopped
            training procedure. Its state is read from
            checkpoint['optimizer_state_dict'] and is restored in full, even
            when ignore_layers is used. Default: None
        use_best (bool, optional): whether to load the model with the best
            score instead of the last saved model. Default: False
        use_cpu (bool, optional): map every tensor to the CPU while loading.
            This avoids allocating CUDA memory when the GPU is unavailable or
            too small. Default: False
        ignore_layers (iterable of str, optional): module names (as they
            appear in state-dict keys, e.g. 'fc' or 'encoder.fc') whose
            parameters and buffers are NOT restored. A key is skipped when
            the part before its last '.' equals one of these names exactly.
            When non-empty, the network is loaded with strict=False, and any
            other key mismatch is logged as a warning. Default: ()

    Returns:
        The value stored under checkpoint['last_step'] (the number of global
        steps done when the checkpoint was saved). Returns 0 if no checkpoint
        file exists or the loaded checkpoint is empty.
    """
    filename = Cls.CHECKPOINT_BEST if use_best else Cls.CHECKPOINT_LAST
    chkpt_path = join_path(save_folder, Cls.CHECKPOINT_SUBDIR, filename)
    if not isfile(chkpt_path):
        logging.info('No checkpoint found at {}, nothing restored'.format(chkpt_path))
        return 0

    # map_location='cpu' avoids loading CUDA tensors when GPU memory is
    # unavailable or too small; None is th.load's default device mapping.
    # NOTE: th.load unpickles the file -- only load checkpoints you trust.
    checkpoint = th.load(chkpt_path, map_location='cpu' if use_cpu else None)
    if not checkpoint:
        return 0

    last_saving_step = checkpoint['last_step']
    chkpt_state_dict = checkpoint['network_state_dict']
    if ignore_layers:
        ignored = set(ignore_layers)
        matched = set()
        # Delete in place (over a copy of the keys) instead of building a new
        # dict: a state dict carries a `_metadata` attribute that
        # load_state_dict uses, and a fresh dict would drop it.
        for p_name in list(chkpt_state_dict):
            module_name = p_name.rpartition('.')[0]
            if module_name in ignored:
                matched.add(module_name)
                del chkpt_state_dict[p_name]
        unmatched = ignored - matched
        if unmatched:
            logging.warning('ignore_layers entries not found in the checkpoint:'
                            ' {}'.format(sorted(unmatched)))

        # strict=False would silently skip ANY mismatch, not only the ignored
        # layers, so report the ones that were not asked for.
        model_keys = set(network.state_dict())
        loaded_keys = set(chkpt_state_dict)
        missing = sorted(k for k in model_keys - loaded_keys
                         if k.rpartition('.')[0] not in ignored)
        unexpected = sorted(loaded_keys - model_keys)
        if missing or unexpected:
            logging.warning('Checkpoint does not match the network outside of'
                            ' ignored layers; missing keys: {}, unexpected'
                            ' keys: {}'.format(missing, unexpected))

        logging.info('Restoring model weights from the previous run,'
                     ' except layers: {}'.format(ignored))
        network.load_state_dict(chkpt_state_dict, strict=False)
    else:
        logging.info('Restoring model weights from the previous run')
        network.load_state_dict(chkpt_state_dict)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return last_saving_step
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment