-
-
Save TransparentLC/5e845b3a668f252e4da774906b000fb8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import collections | |
import os | |
import re | |
import torch | |
from basicsr.archs.rrdbnet_arch import RRDBNet | |
# Reference: https://github.com/xinntao/Real-ESRGAN/blob/master/scripts/pytorch2onnx.py
def _convert_legacy_state(state):
    """Rename the keys of a flat ``model.*`` state dict to RRDBNet names.

    The input is expected to be a mapping whose keys look like
    ``model.0.weight`` / ``model.1.sub.<i>.RDB<j>.conv<k>.0.weight``.
    Returns a new ``OrderedDict``; the input mapping is not modified.

    Raises:
        KeyError: if ``model.0.weight`` or ``model.0.bias`` is missing.
    """
    converted = collections.OrderedDict()
    converted['conv_first.weight'] = state['model.0.weight']
    converted['conv_first.bias'] = state['model.0.bias']

    # body.0.rdb1.conv1.weight <- model.1.sub.0.RDB1.conv1.0.weight
    # body.22.rdb3.conv5.bias  <- model.1.sub.22.RDB3.conv5.0.bias
    rdb_key = re.compile(r'model\.1\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)')
    for old_key, tensor in state.items():
        if m := rdb_key.search(old_key):
            block, rdb, conv, kind = m.groups()
            converted[f'body.{block}.rdb{rdb}.conv{conv}.{kind}'] = tensor

    # The trailing layers are matched purely by position: the last ten
    # entries of the checkpoint are paired, in order, with these names.
    # NOTE(review): this relies on the checkpoint's key order -- confirm for
    # any legacy model that is not a plain 4x ESRGAN.
    tail_names = (
        'conv_body.weight',
        'conv_body.bias',
        'conv_up1.weight',
        'conv_up1.bias',
        'conv_up2.weight',
        'conv_up2.bias',
        'conv_hr.weight',
        'conv_hr.bias',
        'conv_last.weight',
        'conv_last.bias',
    )
    for new_key, old_key in zip(tail_names, tuple(state)[-10:]):
        converted[new_key] = state[old_key]
    return converted


def main(args):
    """Load an RRDBNet checkpoint and export it as an ONNX file.

    Args:
        args: namespace with the attributes
            ``input``  - path of the PyTorch checkpoint to read,
            ``output`` - path of the ONNX file to write; when falsy the
                         input path with its extension replaced by
                         ``.onnx`` is used,
            ``params`` - truthy selects the ``'params'`` entry of the
                         checkpoint, otherwise ``'params_ema'``.

    Raises:
        KeyError: if the checkpoint contains neither the selected entry nor
            the flat ``model.*`` layout handled by the legacy conversion.
    """
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    keyname = 'params' if args.params else 'params_ema'

    # NOTE: torch.load unpickles the file -- only feed it trusted checkpoints.
    state = torch.load(args.input, map_location=torch.device('cpu'))

    # Dispatch on the checkpoint layout explicitly instead of catching a
    # KeyError around load_state_dict, so unrelated errors are not masked.
    if keyname in state:
        model.load_state_dict(state[keyname])
    elif 'model.0.weight' in state:
        model.load_state_dict(_convert_legacy_state(state))
    else:
        raise KeyError(
            f'{keyname!r} not found in checkpoint and it is not a legacy '
            f'"model.*" state dict; available keys: {list(state)[:10]}'
        )

    # eval() already switches training mode off; we only run the forward pass.
    model.cpu().eval()

    # Dummy input used for tracing the exported graph.
    x = torch.rand(1, 3, 64, 64)
    output_path = args.output or (os.path.splitext(args.input)[0] + '.onnx')
    with torch.no_grad():
        torch.onnx.export(model, x, output_path, opset_version=11, export_params=True)
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the conversion.
    parser = argparse.ArgumentParser(description='Convert pytorch model to onnx models')
    # main() reads args.input unconditionally, so the option is mandatory.
    parser.add_argument('--input', type=str, required=True, help='Input model path')
    parser.add_argument('--output', type=str, help='Output onnx path')
    # store_true (not store_false): passing --params must turn the option ON,
    # otherwise the flag does the opposite of what its help text says.
    parser.add_argument('--params', action='store_true', help='Use params instead of params_ema')
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment