@IAmSuyogJadhav
Last active January 3, 2021 05:56
Remove the DataParallel wrapper from PyTorch models trained with nn.DataParallel or nn.DistributedDataParallel. Input: the old state_dict. Output: a new state_dict whose keys no longer carry the `module.` prefix, so it loads into a model that is not wrapped in nn.DataParallel or nn.DistributedDataParallel.
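from collections import OrderedDict


def remove_data_parallel(old_state_dict):
    """Strip the `module.` prefix that nn.DataParallel / nn.DistributedDataParallel add to keys."""
    new_state_dict = OrderedDict()
    for k, v in old_state_dict.items():
        # Remove `module.` only when present; blind k[7:] would mangle unprefixed keys
        name = k[len("module."):] if k.startswith("module.") else k
        new_state_dict[name] = v
    return new_state_dict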
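A minimal sketch of the same idea, generalized to any prefix and run on a stand-in state_dict. The helper name `strip_prefix` and the string values are hypothetical (real checkpoints hold tensors; strings keep the sketch runnable without PyTorch). The unprefixed `fc.weight` key shows why the prefix check matters: the original slice `k[7:]` would turn it into `ht`.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key that has it.

    Hypothetical generalization of remove_data_parallel; keys without the
    prefix are copied through unchanged.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


# Stand-in for a checkpoint saved from an nn.DataParallel-wrapped model.
# Real values would be tensors; strings keep this sketch PyTorch-free.
wrapped = OrderedDict([
    ("module.conv1.weight", "w1"),
    ("module.conv1.bias", "b1"),
    ("fc.weight", "w2"),  # no prefix: must not be truncated
])

plain = strip_prefix(wrapped)
print(list(plain.keys()))  # ['conv1.weight', 'conv1.bias', 'fc.weight']
```

In practice the result is passed to the unwrapped model, e.g. `model.load_state_dict(remove_data_parallel(torch.load(path)))`, where `model` and `path` stand for your own model and checkpoint file. Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which strips a prefix in place.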