@nlp4whp, forked from the-bass/rename_state_dict_keys.py
Rename the parameters of a PyTorch module's saved state dict. Last tested with PyTorch 1.0.1.
rename_state_dict_keys.py

import torch
from collections import OrderedDict


def rename_state_dict_keys(source, key_transformation, target=None):
    """
    source -> Path to the saved state dict.
    key_transformation -> Function that accepts the old key names of the state
        dict as the only argument and returns the new key name.
    target (optional) -> Path at which the new state dict should be saved
        (defaults to `source`).

    Example:
    Rename the key `layer.0.weight` to `layer.1.weight` and keep the names of
    all other keys.

    ```py
    def key_transformation(old_key):
        if old_key == "layer.0.weight":
            return "layer.1.weight"
        return old_key

    rename_state_dict_keys(state_dict_path, key_transformation)
    ```
    """
    if target is None:
        target = source

    state_dict = torch.load(source)
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = key_transformation(key)
        new_state_dict[new_key] = value

    torch.save(new_state_dict, target)
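
A common real-world use of this helper is stripping the `module.` prefix that `torch.nn.DataParallel` prepends to every parameter name when a wrapped model is saved. A minimal sketch, assuming a checkpoint at the hypothetical path `./dataparallel_checkpoint.torch`:

def strip_module_prefix(old_key):
    # Keys saved from a DataParallel-wrapped model look like
    # "module.layer.0.weight"; drop the leading "module." if present.
    if old_key.startswith("module."):
        return old_key[len("module."):]
    return old_key

rename_state_dict_keys('./dataparallel_checkpoint.torch', strip_module_prefix)

Because the values are copied into a fresh OrderedDict in iteration order, the renamed state dict keeps the original parameter order. The accompanying test below exercises the renaming end to end: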
import unittest
import torch
import os
from rename_state_dict_keys import rename_state_dict_keys


class TestRenameStateDictKeys(unittest.TestCase):

    def test_renaming(self):
        """
        Define some modules for the test.
        """
        import torch.nn as nn

        class SimpleModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = nn.Sequential(
                    nn.Linear(in_features=2, out_features=1, bias=False)
                )

            def forward(self, x):
                return self.layer(x)

        class SimpleModuleWithDropout(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = nn.Sequential(
                    nn.Dropout(p=0),
                    nn.Linear(in_features=2, out_features=1, bias=False)
                )

            def forward(self, x):
                return self.layer(x)

        """
        Try to load the state dict of SimpleModule into SimpleModuleWithDropout
        by renaming the parameters for the linear layer.
        """
        state_dict_path = './state_dict.torch'
        x = torch.tensor([1.0, 10.0])
        simple_module = SimpleModule()
        simple_module_with_dropout = SimpleModuleWithDropout()
        torch.save(simple_module.state_dict(), state_dict_path)

        # The test only works if, at this point, the results differ.
        self.assertNotEqual(simple_module(x), simple_module_with_dropout(x))

        # Before renaming, loading the state dict is expected to fail:
        # SimpleModule stores the weight under `layer.0.weight`, while
        # SimpleModuleWithDropout expects `layer.1.weight`.
        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
            loaded_state_dict = torch.load(state_dict_path)
            simple_module_with_dropout.load_state_dict(loaded_state_dict)

        # Rename the parameters.
        def key_transformation(old_key):
            if old_key == "layer.0.weight":
                return "layer.1.weight"
            return old_key

        rename_state_dict_keys(state_dict_path, key_transformation)

        # Loading the state dict should succeed now thanks to the renaming.
        loaded_state_dict = torch.load(state_dict_path)
        simple_module_with_dropout.load_state_dict(loaded_state_dict)

        # Since both modules now hold the same parameter values, the results
        # should be equal. (Dropout with p=0 is the identity.)
        self.assertEqual(simple_module(x), simple_module_with_dropout(x))

        # Clean up.
        os.remove(state_dict_path)


if __name__ == '__main__':
    unittest.main()
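
To run the test, save the two files side by side; the first must be named rename_state_dict_keys.py so the import resolves, while the test file name is arbitrary (e.g. test_rename_state_dict_keys.py). Then:

python test_rename_state_dict_keys.py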