Skip to content

Instantly share code, notes, and snippets.

@eschmidbauer
Created April 16, 2024 19:47
Show Gist options
  • Save eschmidbauer/bbee17275d9c32921a6d2c854463e784 to your computer and use it in GitHub Desktop.
Save eschmidbauer/bbee17275d9c32921a6d2c854463e784 to your computer and use it in GitHub Desktop.
from collections import OrderedDict
import torch
# Convert a plain checkpoint into one that PyTorch Lightning can load: every
# state_dict key gets a prefix (default "net." — presumably the attribute name
# of the wrapped model inside the LightningModule; confirm against the module
# that consumes the file) and minimal Lightning metadata keys are added.
checkpoint_path = "multilingual_base_bs100x4.ckpt"
modified_checkpoint_path = "multilingual_base_bs100x4.ckpt.tar"


def prefix_state_dict(state_dict, prefix="net."):
    """Return a new ``OrderedDict`` with *prefix* prepended to every key.

    Key order is preserved and values are shared with *state_dict* (not
    copied). *state_dict* itself is not modified.
    """
    return OrderedDict((f"{prefix}{key}", value) for key, value in state_dict.items())


def build_lightning_checkpoint(
    checkpoint,
    prefix="net.",
    *,
    lightning_version="0.0.0",
    global_step=1,
    epoch=100,
):
    """Build a Lightning-style checkpoint dict from a loaded *checkpoint*.

    Args:
        checkpoint: Mapping that contains a ``"state_dict"`` entry.
        prefix: String prepended to every key of the state dict.
        lightning_version: Value stored under ``"pytorch-lightning_version"``.
        global_step: Value stored under ``"global_step"``.
        epoch: Value stored under ``"epoch"``.

    Returns:
        An ``OrderedDict`` with the keys ``state_dict``,
        ``pytorch-lightning_version``, ``global_step`` and ``epoch`` (in that
        order). Every other top-level entry of *checkpoint* is dropped.

    Raises:
        KeyError: if *checkpoint* has no ``"state_dict"`` entry.
    """
    try:
        state_dict = checkpoint["state_dict"]
    except KeyError:
        raise KeyError(
            f"checkpoint has no 'state_dict' entry; top-level keys are {list(checkpoint)}"
        ) from None
    # The metadata values look like placeholders that only need to exist for
    # Lightning's loader to accept the file — TODO confirm none are consumed.
    return OrderedDict(
        [
            ("state_dict", prefix_state_dict(state_dict, prefix)),
            ("pytorch-lightning_version", lightning_version),
            ("global_step", global_step),
            ("epoch", epoch),
        ]
    )


def convert_checkpoint(src_path=checkpoint_path, dst_path=modified_checkpoint_path, prefix="net."):
    """Load *src_path*, prefix its state-dict keys and save the result to *dst_path*.

    Tensors are loaded onto the CPU. Returns the dict that was saved.
    """
    # NOTE: torch.load unpickles the file, which can execute arbitrary code —
    # only convert checkpoints from a trusted source. Newer PyTorch releases
    # default to weights_only=True, so a checkpoint holding arbitrary Python
    # objects may need weights_only=False passed explicitly.
    checkpoint = torch.load(src_path, map_location=torch.device("cpu"))
    modified_checkpoint = build_lightning_checkpoint(checkpoint, prefix)
    torch.save(modified_checkpoint, dst_path)
    return modified_checkpoint


if __name__ == "__main__":
    convert_checkpoint()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment