Created
August 21, 2019 11:03
-
-
Save griver/97205b90fe3d7df1577d591cd90d0a2d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@classmethod
def update_from_checkpoint(cls, save_folder, network, optimizer=None,
                           use_best=False, use_cpu=False, ignore_layers=tuple()):
    """
    Update network and optimizer (if specified) from a checkpoint in the save folder.

    The checkpoint is read from
    ``join_path(save_folder, cls.CHECKPOINT_SUBDIR, filename)``, where filename
    is ``cls.CHECKPOINT_BEST`` if ``use_best`` is True and ``cls.CHECKPOINT_LAST``
    otherwise. If that file does not exist, nothing is loaded and 0 is returned.

    Arguments:
        save_folder (str): a path to a folder containing summaries and weights
            of the pretrained model.
        network (torch.nn.Module): a model we want to update with weights from
            the checkpoint (taken from the checkpoint's 'network_state_dict').
        optimizer (torch.optim.Optimizer, optional): sometimes an optimizer needs
            to be updated along with the model, e.g. when continuing a previously
            stopped training procedure. Its state is taken from the checkpoint's
            'optimizer_state_dict'. Default: None
        use_best (bool, optional): if True, load the checkpoint with the best
            score instead of the last saved one. Default: False
        use_cpu (bool, optional): if True, load all tensors onto the CPU
            (``map_location='cpu'``), e.g. when GPU memory is unavailable or
            too small. Default: False
        ignore_layers (str or iterable of str, optional): names of submodules
            whose entries are dropped from the checkpoint before loading.
            An entry 'a.b.weight' belongs to module 'a.b'. When non-empty, the
            network is loaded with ``strict=False``, so those layers keep their
            current values. Default: ()

    Returns:
        The checkpoint's 'last_step' value (number of global steps done when
        it was saved), or 0 if no checkpoint file was found.
    """
    filename = cls.CHECKPOINT_BEST if use_best else cls.CHECKPOINT_LAST
    chkpt_path = join_path(save_folder, cls.CHECKPOINT_SUBDIR, filename)
    if not isfile(chkpt_path):
        logging.info('No checkpoint found at {}, nothing to restore'.format(chkpt_path))
        return 0

    if use_cpu:
        # avoids loading cuda tensors if the gpu memory is unavailable or too small
        checkpoint = th.load(chkpt_path, map_location='cpu')
    else:
        checkpoint = th.load(chkpt_path)

    chkpt_state_dict = checkpoint['network_state_dict']

    # A bare string would otherwise be split into single characters by set().
    if isinstance(ignore_layers, str):
        ignore_layers = (ignore_layers,)
    ignore_layers = set(ignore_layers)

    if ignore_layers:
        # Delete in place instead of building a new dict: PyTorch state dicts
        # carry a `_metadata` attribute (module versions) that load_state_dict
        # consults, and a freshly built dict would drop it.
        for p_name in list(chkpt_state_dict):
            module_name = p_name.rpartition('.')[0]
            if module_name in ignore_layers:
                del chkpt_state_dict[p_name]
        logging.info('Restoring model weights from the previous run,'
                     ' except layers: {}'.format(ignore_layers))
        network.load_state_dict(chkpt_state_dict, strict=False)
    else:
        logging.info('Restoring model weights from the previous run')
        network.load_state_dict(chkpt_state_dict)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['last_step']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment