Last active: November 17, 2019 18:18
Save ProGamerGov/1bc833a8ae91f81e7e40037d052f8193 to your computer and use it in GitHub Desktop.
Make a model's weight and bias names compatible with neural-style-pt.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os
from collections import OrderedDict

import torch
# Command-line interface. `params` is read as a module-level global by main().
parser = argparse.ArgumentParser(
    description="Rename a model's weight and bias keys so the state dict is "
                "compatible with neural-style-pt.")
# An empty default could only ever make torch.load('') fail, so require it.
parser.add_argument("-model_file", type=str, required=True,
                    help="Path to the state dict to convert.")
parser.add_argument("-output_name", type=str, default='',
                    help="Where to save the converted state dict "
                         "('.pth' is appended when missing).")
params = parser.parse_args()
def main():
    """Convert ``params.model_file`` to neural-style-pt key names and save it.

    The output path is ``params.output_name``, with '.pth' appended when it
    does not already end with it. When no output name was given, the model
    file's name (minus its extension) plus '_nstpt' is used instead, so the
    result is never written to a bare, hidden '.pth' file.
    """
    # NOTE: torch.load unpickles the file -- only convert trusted checkpoints.
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
    state_dict = torch.load(params.model_file, map_location='cpu')
    new_state_dict = rename_items(state_dict, get_net_type(state_dict))

    output_name = params.output_name
    if not output_name:
        output_name = os.path.splitext(params.model_file)[0] + '_nstpt'
    # endswith, not a substring test: 'x.pth.bak' must still get the suffix.
    if not output_name.endswith('.pth'):
        output_name += '.pth'
    torch.save(new_state_dict, output_name)
# Layer indices for each supported architecture, keyed by the number of
# entries in the state dict (every layer contributes one weight and one bias).
# Indices ending in 'c' are classifier layers; the rest are feature layers.
_LAYER_NUMS = {
    # VGG16, with and without the final classifier layer ('6c').
    32: ['0', '2', '5', '7', '10', '12', '14', '17', '19', '21', '24', '26', '28',
         '0c', '3c', '6c'],
    30: ['0', '2', '5', '7', '10', '12', '14', '17', '19', '21', '24', '26', '28',
         '0c', '3c'],
    # VGG19, with and without the final classifier layer ('6c').
    38: ['0', '2', '5', '7', '10', '12', '14', '16', '19', '21', '23', '25', '28',
         '30', '32', '34', '0c', '3c', '6c'],
    36: ['0', '2', '5', '7', '10', '12', '14', '16', '19', '21', '23', '25', '28',
         '30', '32', '34', '0c', '3c'],
    # NIN
    24: ['0', '2', '4', '7', '9', '11', '14', '16', '18', '22', '24', '26'],
}


def get_net_type(state_dict):
    """Return the neural-style-pt layer index for every entry of *state_dict*.

    The architecture is identified only by how many entries the state dict
    has. Each layer index is emitted twice in a row, once for the layer's
    weight and once for its bias, so the result lines up position-by-position
    with ``state_dict.items()``.

    Raises:
        ValueError: if the entry count matches no supported architecture.
    """
    try:
        layer_nums = _LAYER_NUMS[len(state_dict)]
    except KeyError:
        raise ValueError(
            'Unsupported model: state dict has {} entries, expected one of {}'
            .format(len(state_dict), sorted(_LAYER_NUMS))) from None
    return [num for num in layer_nums for _ in range(2)]
def rename_items(state_dict, layer_nums):
    """Return a copy of *state_dict* with keys renamed to neural-style-pt names.

    Entries are paired with *layer_nums* by position. A layer number containing
    'c' maps to ``classifier.<n>`` (with the 'c' removed); any other maps to
    ``features.<n>``. One-dimensional tensors get the ``.bias`` suffix and all
    others ``.weight``. Each mapping is printed as ``old --> new``.

    Raises:
        ValueError: if *layer_nums* and *state_dict* differ in length, or if
            two entries would be renamed to the same key.
    """
    # Positional pairing is only meaningful when the lengths agree; checking
    # up front avoids a half-printed rename followed by an IndexError.
    if len(layer_nums) != len(state_dict):
        raise ValueError(
            'Got {} layer numbers for {} state dict entries'
            .format(len(layer_nums), len(state_dict)))

    new_state_dict = OrderedDict()
    for (old_name, tensor), num in zip(state_dict.items(), layer_nums):
        if 'c' in num:
            prefix, num = 'classifier.', num.replace('c', '')
        else:
            prefix = 'features.'
        suffix = '.bias' if len(tensor.size()) == 1 else '.weight'
        new_name = prefix + num + suffix
        # A clash means the entries are not (weight, bias) pairs in layer
        # order; silently overwriting would drop a tensor.
        if new_name in new_state_dict:
            raise ValueError(
                'Duplicate target key {!r} for {!r}'.format(new_name, old_name))
        new_state_dict[new_name] = tensor
        print(old_name + ' --> ' + new_name)
    return new_state_dict
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.