Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Last active November 17, 2019 18:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ProGamerGov/1bc833a8ae91f81e7e40037d052f8193 to your computer and use it in GitHub Desktop.
Save ProGamerGov/1bc833a8ae91f81e7e40037d052f8193 to your computer and use it in GitHub Desktop.
Make a model's weight and bias names compatible with neural-style-pt
import torch
from collections import OrderedDict
import argparse
# Command-line interface. Both options use a single leading dash and default
# to the empty string when omitted.
parser = argparse.ArgumentParser()
# Path of the checkpoint that main() loads with torch.load.
parser.add_argument("-model_file", type=str, default='')
# File name that main() passes to torch.save for the renamed state dict.
parser.add_argument("-output_name", type=str, default='')
# NOTE: arguments are parsed when this module is executed/imported, so
# `params` is a module-level global that main() reads.
params = parser.parse_args()
def main():
    """Load a state dict, rename its entries, and save the result.

    Reads ``params.model_file`` with ``torch.load``, renames every entry via
    ``rename_items`` using the layer list chosen by ``get_net_type``, and
    writes the renamed dict to ``params.output_name`` with ``torch.save``.
    ``params.output_name`` is updated in place when the ``.pth`` suffix has
    to be appended.
    """
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    state_dict = torch.load(params.model_file)
    new_state_dict = rename_items(state_dict, get_net_type(state_dict))
    # Only a trailing '.pth' counts as the extension; a '.pth' elsewhere in
    # the name (e.g. inside a directory component) must not suppress it.
    if not params.output_name.endswith('.pth'):
        params.output_name = params.output_name + '.pth'
    torch.save(new_state_dict, params.output_name)
def get_net_type(state_dict):
    """Choose the layer identifiers for a model based on its entry count.

    The layout is selected purely from ``len(state_dict)``:

    * 32 / 30 entries -> VGG16 (with / without the ``'6c'`` layer)
    * 38 / 36 entries -> VGG19 (with / without the ``'6c'`` layer)
    * 24 entries      -> NIN

    Identifiers ending in ``'c'`` are the ones ``rename_items`` places under
    ``classifier.``; all others go under ``features.``.

    Args:
        state_dict: Mapping whose number of entries identifies the model.

    Returns:
        A list of layer-identifier strings with every identifier repeated
        twice in a row, so the list has exactly one element per entry of
        ``state_dict`` (two entries per layer).

    Raises:
        ValueError: If the entry count matches none of the known layouts.
    """
    vgg16 = ['0', '2', '5', '7', '10', '12', '14', '17', '19', '21', '24',
             '26', '28', '0c', '3c', '6c']
    vgg19 = ['0', '2', '5', '7', '10', '12', '14', '16', '19', '21', '23',
             '25', '28', '30', '32', '34', '0c', '3c', '6c']
    nin = ['0', '2', '4', '7', '9', '11', '14', '16', '18', '22', '24', '26']

    # Entry count -> layer identifiers (one identifier per layer).
    layouts = {
        32: vgg16,
        30: vgg16[:-1],  # VGG16 without the '6c' layer
        38: vgg19,
        36: vgg19[:-1],  # VGG19 without the '6c' layer
        24: nin,
    }

    num_entries = len(state_dict)
    if num_entries not in layouts:
        raise ValueError(
            'Unsupported model: state dict has {} entries, expected one of {}'
            .format(num_entries, sorted(layouts)))

    # Each layer contributes two consecutive entries, so repeat every
    # identifier twice to line the list up with the state dict.
    return [layer for layer in layouts[num_entries] for _ in range(2)]
def rename_items(state_dict, layer_nums):
    """Build a new state dict whose keys follow the features/classifier scheme.

    Entries of ``state_dict`` are paired, in iteration order, with the
    identifiers in ``layer_nums``. An identifier containing ``'c'`` maps to
    ``classifier.<id without 'c'>``; any other identifier maps to
    ``features.<id>``. The suffix is ``.bias`` when the value's ``size()``
    has exactly one dimension and ``.weight`` otherwise. Each mapping is
    printed as ``old --> new``.

    Args:
        state_dict: Mapping of original names to values exposing ``size()``.
        layer_nums: Sequence of layer-identifier strings, one per entry of
            ``state_dict`` (extra trailing identifiers are ignored).

    Returns:
        An ``OrderedDict`` mapping the new names to the same value objects.

    Raises:
        IndexError: If ``layer_nums`` has fewer elements than ``state_dict``.
    """
    # Fail up front with a clear message instead of part-way through the loop.
    if len(layer_nums) < len(state_dict):
        raise IndexError(
            'layer_nums has {} entries but state_dict has {}'
            .format(len(layer_nums), len(state_dict)))

    new_state_dict = OrderedDict()
    for (k, v), n in zip(state_dict.items(), layer_nums):
        if 'c' in n:
            prefix, layer = 'classifier.', n.replace('c', '')
        else:
            prefix, layer = 'features.', n
        # One-dimensional values are treated as biases, everything else as weights.
        suffix = '.bias' if len(v.size()) == 1 else '.weight'
        new_name = prefix + layer + suffix
        new_state_dict[new_name] = v
        print(k + ' --> ' + new_name)
    return new_state_dict
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment