Skip to content

Instantly share code, notes, and snippets.

@BIGBALLON
Created March 26, 2021 13:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BIGBALLON/fba7720e1cd19c6771add93379707e22 to your computer and use it in GitHub Desktop.
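ddp->single: convert a checkpoint saved from a DataParallel/DDP-wrapped model (parameter names prefixed with `module.`) into one that loads into the plain, unwrapped model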
# ddp -> single: rewrite a checkpoint saved from a DataParallel/DDP-wrapped
# model so that its "state_dict" entry loads into the plain (unwrapped) model.
from collections import OrderedDict

import torch

_DDP_PREFIX = "module."  # prefix the DataParallel wrapper puts on parameter names

# map_location="cpu" lets the conversion run on a machine that does not have
# the GPU(s) the checkpoint was saved from.
checkpoint = torch.load("resnet18.pth.tar", map_location="cpu")
state_dict = checkpoint["state_dict"]

# Strip the prefix only where it is actually present; blindly slicing off the
# first 7 characters would mangle any key that was never prefixed.
new_state_dict = OrderedDict(
    (k[len(_DDP_PREFIX):] if k.startswith(_DDP_PREFIX) else k, v)
    for k, v in state_dict.items()
)
checkpoint["state_dict"] = new_state_dict
torch.save(checkpoint, "resnet18_2.pth.tar")
@BIGBALLON
Copy link
Author

import torch


def convert(
    pretrained_weights, save_name, checkpoint_key="student", prefix="module."
):
    """Load a checkpoint, optionally select a sub-dict, strip a key prefix, save it.

    Args:
        pretrained_weights: Path passed to ``torch.load`` (tensors are mapped
            to CPU).
        save_name: Path the converted state dict is written to via
            ``torch.save``.
        checkpoint_key: If not None and present as a key of the loaded object,
            that entry is used as the state dict. Otherwise the loaded object
            itself is treated as the state dict.
        prefix: Leading string removed from every key that starts with it.
            Defaults to the ``"module."`` prefix added by DataParallel/DDP.
            Keys without the prefix are kept unchanged.
    """
    state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        print(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # Remove only a *leading* prefix: str.replace would also rewrite inner
    # occurrences, e.g. "net.submodule.w" -> "net.subw".
    n = len(prefix)
    state_dict = {
        (k[n:] if k.startswith(prefix) else k): v for k, v in state_dict.items()
    }
    torch.save(state_dict, save_name)


# Convert the swin checkpoint into a plain state dict: use its default
# "student" entry when present and drop the DataParallel "module." prefix
# from parameter names.
convert(
    pretrained_weights="swin_fpn0.0_clip0.3_0521.pth",
    save_name="std_swin_fpn0.0_clip0.3_0521.pth",
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment