Skip to content

Instantly share code, notes, and snippets.

@TransparentLC
Last active March 16, 2023 06:27
Show Gist options
  • Save TransparentLC/5e845b3a668f252e4da774906b000fb8 to your computer and use it in GitHub Desktop.
Save TransparentLC/5e845b3a668f252e4da774906b000fb8 to your computer and use it in GitHub Desktop.
import argparse
import collections
import os
import re
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
# https://github.com/xinntao/Real-ESRGAN/blob/master/scripts/pytorch2onnx.py
def main(args):
    """Convert an x4 RRDBNet checkpoint into an ONNX model.

    Two checkpoint layouts are accepted:

    * BasicSR / Real-ESRGAN style, where the weights are stored under a
      ``'params'`` or ``'params_ema'`` key;
    * the old ESRGAN layout with flat ``model.*`` keys, which is renamed to
      RRDBNet parameter names by ``_convert_old_esrgan_state``.

    Args:
        args: Parsed command-line namespace with
            ``input``  -- path of the PyTorch checkpoint to read;
            ``output`` -- ONNX file to write; when empty/None it defaults to
                          ``input`` with its extension replaced by ``.onnx``;
            ``params`` -- truthy to prefer the ``'params'`` weights,
                          falsy to prefer ``'params_ema'``.
    """
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only convert checkpoints obtained from trusted sources.
    state = torch.load(args.input, map_location=torch.device('cpu'))

    preferred, other = ('params', 'params_ema') if args.params else ('params_ema', 'params')
    if preferred in state:
        weights = state[preferred]
    elif other in state:
        # Look the key up explicitly instead of catching KeyError: a BasicSR
        # checkpoint that merely lacks the preferred copy must not be misrouted
        # to the old-layout converter (which would fail on 'model.0.weight').
        print(f"'{preferred}' not found in {args.input}, using '{other}' instead")
        weights = state[other]
    else:
        weights = _convert_old_esrgan_state(state)
    model.load_state_dict(weights)

    # Only the forward pass is traced, so switch to inference mode.
    model.cpu().eval()

    # Dummy input used only to trace the graph.
    # NOTE(review): no dynamic_axes are passed, so the exported input shape is
    # presumably fixed at 1x3x64x64 -- confirm this suits the target runtime.
    dummy_input = torch.rand(1, 3, 64, 64)
    output_path = args.output or (os.path.splitext(args.input)[0] + '.onnx')
    with torch.no_grad():
        torch.onnx.export(model, dummy_input, output_path, opset_version=11, export_params=True)


def _convert_old_esrgan_state(state):
    """Rename an old-layout ESRGAN state dict to RRDBNet parameter names.

    Mapping:
      model.0.{weight,bias}                          -> conv_first.{weight,bias}
      model.1.sub.<i>.RDB<j>.conv<k>.0.{weight,bias} -> body.<i>.rdb<j>.conv<k>.{weight,bias}
      the last ten entries of ``state``, in order    -> conv_body, conv_up1, conv_up2,
                                                        conv_hr, conv_last (weight, bias each)

    Returns:
        A new ``collections.OrderedDict`` holding the renamed tensors.

    Raises:
        KeyError: if ``state`` has no ``model.0.weight`` / ``model.0.bias``.
    """
    converted = collections.OrderedDict()
    converted['conv_first.weight'] = state['model.0.weight']
    converted['conv_first.bias'] = state['model.0.bias']

    # e.g. body.0.rdb1.conv1.weight <- model.1.sub.0.RDB1.conv1.0.weight
    #      body.22.rdb3.conv5.bias  <- model.1.sub.22.RDB3.conv5.0.bias
    rdb_pattern = re.compile(r'model\.1\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)')
    for key, value in state.items():
        if match := rdb_pattern.search(key):
            block, rdb, conv, kind = match.groups()
            converted[f'body.{block}.rdb{rdb}.conv{conv}.{kind}'] = value

    tail_names = (
        'conv_body.weight',
        'conv_body.bias',
        'conv_up1.weight',
        'conv_up1.bias',
        'conv_up2.weight',
        'conv_up2.bias',
        'conv_hr.weight',
        'conv_hr.bias',
        'conv_last.weight',
        'conv_last.bias',
    )
    # The tail layers are matched by position, not by name: this relies on the
    # checkpoint's key order ending with exactly these ten tensors.
    # NOTE(review): confirm against the checkpoints being converted.
    for name, key in zip(tail_names, tuple(state)[-len(tail_names):]):
        converted[name] = state[key]
    return converted
if __name__ == '__main__':
"""Convert pytorch model to onnx models"""
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='Input model path')
parser.add_argument('--output', type=str, help='Output onnx path')
parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
args = parser.parse_args()
main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment