Skip to content

Instantly share code, notes, and snippets.

@the-bass
Last active January 1, 2024 08:28
Show Gist options
  • Star 24 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save the-bass/0bf8aaa302f9ba0d26798b11e4dd73e3 to your computer and use it in GitHub Desktop.
Save the-bass/0bf8aaa302f9ba0d26798b11e4dd73e3 to your computer and use it in GitHub Desktop.
Rename the keys (parameter names) of a PyTorch module's saved state dict. Last tested with PyTorch 1.0.1.
import torch
from collections import OrderedDict
def rename_state_dict_keys(source, key_transformation, target=None, map_location=None):
    """Rename the keys of a saved PyTorch state dict and save the result.

    Args:
        source: Path to the saved state dict.
        key_transformation: Function that accepts an old key name of the state
            dict as its only argument and returns the new key name.
        target: Optional path at which the new state dict should be saved.
            Defaults to ``source``, i.e. the file is overwritten in place.
        map_location: Optional ``map_location`` forwarded to ``torch.load``
            (e.g. ``"cpu"`` to load tensors saved on a GPU on a CPU-only
            machine). Defaults to ``None``, ``torch.load``'s own default.

    Raises:
        ValueError: If ``key_transformation`` maps two different old keys to
            the same new key. Nothing is written in that case, so the files at
            ``source`` and ``target`` are left untouched.

    Example:
        Rename the key ``layer.0.weight`` to ``layer.1.weight`` and keep the
        names of all other keys::

            def key_transformation(old_key):
                if old_key == "layer.0.weight":
                    return "layer.1.weight"
                return old_key

            rename_state_dict_keys(state_dict_path, key_transformation)
    """
    if target is None:
        target = source

    state_dict = torch.load(source, map_location=map_location)

    new_state_dict = OrderedDict()
    # new key -> the old key it came from, used only to report collisions.
    origin = {}
    for key, value in state_dict.items():
        new_key = key_transformation(key)
        # Two old keys mapping to one new key would silently drop a tensor.
        # Refuse before saving, because target may be the source file itself.
        if new_key in new_state_dict:
            raise ValueError(
                f"key_transformation maps both {origin[new_key]!r} and "
                f"{key!r} to the same key {new_key!r}"
            )
        origin[new_key] = key
        new_state_dict[new_key] = value

    # nn.Module.state_dict() attaches per-submodule version info as
    # `_metadata`; a plain OrderedDict copy would lose it. Carry it over
    # unchanged, as nn.Module.load_state_dict does when it copies a state dict.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        new_state_dict._metadata = metadata

    torch.save(new_state_dict, target)
import unittest
import torch
import os
from rename_state_dict_keys import rename_state_dict_keys
class TestRenameStateDictKeys(unittest.TestCase):
    """Tests for ``rename_state_dict_keys``."""

    def test_renaming(self):
        """Renaming lets a saved state dict load into a differently nested module.

        ``SimpleModule`` stores its linear weight under ``layer.0.weight``.
        ``SimpleModuleWithDropout`` expects it under ``layer.1.weight``
        because a ``Dropout`` occupies index 0 of its ``Sequential``.
        """
        import tempfile

        import torch.nn as nn

        # Define some modules for the test.
        class SimpleModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = nn.Sequential(
                    nn.Linear(in_features=2, out_features=1, bias=False)
                )

            def forward(self, x):
                return self.layer(x)

        class SimpleModuleWithDropout(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = nn.Sequential(
                    nn.Dropout(p=0),
                    nn.Linear(in_features=2, out_features=1, bias=False)
                )

            def forward(self, x):
                return self.layer(x)

        # Try to load the state dict of SimpleModule into SimpleModuleWithDropout
        # by renaming the parameters for the linear layer.

        # Use a private temporary directory rather than the current working
        # directory. addCleanup removes it even if an assertion below fails.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        state_dict_path = os.path.join(tmp_dir.name, 'state_dict.torch')

        x = torch.tensor([1.0, 10])
        simple_module = SimpleModule()
        simple_module_with_dropout = SimpleModuleWithDropout()
        torch.save(simple_module.state_dict(), state_dict_path)

        # The test only works if at this point the results are different.
        self.assertFalse(
            torch.equal(simple_module(x), simple_module_with_dropout(x))
        )

        # Before renaming, loading the state dict is expected to fail.
        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
            loaded_state_dict = torch.load(state_dict_path)
            simple_module_with_dropout.load_state_dict(loaded_state_dict)

        # Rename the parameters.
        def key_transformation(old_key):
            if old_key == "layer.0.weight":
                return "layer.1.weight"
            return old_key

        rename_state_dict_keys(state_dict_path, key_transformation)

        # Loading the state dict should succeed now due to the renaming.
        loaded_state_dict = torch.load(state_dict_path)
        simple_module_with_dropout.load_state_dict(loaded_state_dict)

        # Since both modules should have the same parameter values now, the
        # results should be equal.
        self.assertTrue(
            torch.equal(simple_module(x), simple_module_with_dropout(x))
        )
# Run the test case(s) above when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
@Praveenk8051
Copy link

Can we also rename the labels after training the model?

I have a trained model. During training there was a problem with string characters, so I converted my labels into numbers, like this:

red : 0

blue: 1

green: 2

Is it now possible to map the numbers back to the actual label names? The model took hours to train, so it would be helpful if anyone has an idea.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment