Last active
May 12, 2024 17:37
-
-
Save the-bass/0bf8aaa302f9ba0d26798b11e4dd73e3 to your computer and use it in GitHub Desktop.
Rename the parameters of a PyTorch module's saved state dict. Last tested with PyTorch 1.0.1.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from collections import OrderedDict | |
def rename_state_dict_keys(source, key_transformation, target=None, map_location=None):
    """
    Rename the keys of a saved state dict and save the result.

    source -> Path to the saved state dict.
    key_transformation -> Function that accepts the old key name of the state
        dict as the only argument and returns the new key name.
    target (optional) -> Path at which the new state dict should be saved
        (defaults to `source`, i.e. the file is overwritten in place).
    map_location (optional) -> Forwarded to `torch.load`; e.g. pass "cpu" to
        rename a state dict saved from GPU tensors on a CPU-only machine.

    Raises ValueError (before anything is written) if `key_transformation`
    maps two different keys to the same new key, because one of the values
    would otherwise be silently dropped.

    Example:
        Rename the key `layer.0.weight` to `layer.1.weight` and keep the names
        of all other keys.

        ```py
        def key_transformation(old_key):
            if old_key == "layer.0.weight":
                return "layer.1.weight"
            return old_key

        rename_state_dict_keys(state_dict_path, key_transformation)
        ```
    """
    if target is None:
        target = source

    # NOTE(security): torch.load unpickles the file; only use this on
    # trusted files (older PyTorch versions allow arbitrary code execution).
    state_dict = torch.load(source, map_location=map_location)

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key_transformation(key)
        # Refuse to overwrite an already-renamed entry: losing a tensor
        # silently would only surface later as a confusing load error.
        if new_key in new_state_dict:
            raise ValueError(
                "key_transformation maps {!r} to {!r}, which is already "
                "used by another key".format(key, new_key)
            )
        new_state_dict[new_key] = value

    # Module.state_dict() attaches version info as `_metadata`, and
    # Module.load_state_dict() reads it back; carry it over so it is not lost.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        new_state_dict._metadata = metadata

    torch.save(new_state_dict, target)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import torch | |
import os | |
from rename_state_dict_keys import rename_state_dict_keys | |
class TestRenameStateDictKeys(unittest.TestCase):
    def test_renaming(self):
        """
        Check that renaming keys lets a state dict saved from one module be
        loaded into a module that stores the same parameter under a different
        name.
        """
        import tempfile

        import torch.nn as nn

        # Define some modules for the test.
        class SimpleModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = nn.Sequential(
                    nn.Linear(in_features=2, out_features=1, bias=False)
                )

            def forward(self, x):
                return self.layer(x)

        class SimpleModuleWithDropout(nn.Module):
            def __init__(self):
                super().__init__()
                # The leading Dropout moves the Linear layer to index 1 of the
                # Sequential, so its weight key becomes `layer.1.weight`.
                # p=0 makes the dropout a no-op, so outputs match once the
                # weights match.
                self.layer = nn.Sequential(
                    nn.Dropout(p=0),
                    nn.Linear(in_features=2, out_features=1, bias=False)
                )

            def forward(self, x):
                return self.layer(x)

        # Try to load the state dict of SimpleModule into
        # SimpleModuleWithDropout by renaming the parameters for the linear
        # layer. Use a private temporary directory rather than the working
        # directory; addCleanup removes it even if an assertion fails.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        state_dict_path = os.path.join(tmp_dir.name, 'state_dict.torch')

        x = torch.tensor([1.0, 10])
        simple_module = SimpleModule()
        simple_module_with_dropout = SimpleModuleWithDropout()
        torch.save(simple_module.state_dict(), state_dict_path)

        # The test only works if at this point the results are different.
        self.assertFalse(
            torch.equal(simple_module(x), simple_module_with_dropout(x))
        )

        # Before renaming, loading the state dict is expected to fail.
        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
            loaded_state_dict = torch.load(state_dict_path)
            simple_module_with_dropout.load_state_dict(loaded_state_dict)

        # Rename the parameters.
        def key_transformation(old_key):
            if old_key == "layer.0.weight":
                return "layer.1.weight"
            return old_key

        rename_state_dict_keys(state_dict_path, key_transformation)

        # Loading the state dict should succeed now due to the renaming.
        loaded_state_dict = torch.load(state_dict_path)
        simple_module_with_dropout.load_state_dict(loaded_state_dict)

        # Since both modules should have the same parameter values now, the
        # results should be equal. torch.equal compares whole tensors of any
        # shape.
        self.assertTrue(
            torch.equal(simple_module(x), simple_module_with_dropout(x))
        )


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Can we also rename the labels after the model has been trained?
I have a trained model. During training there was a problem with string characters, so I converted my labels into numbers like:
red: 0
blue: 1
green: 2
Is it now possible to rename my labels back to their actual names? Training took hours, so any ideas would be helpful.