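"""Convert RT-DETR weights released by PaddleDetection (ppdet) into the key layout
expected by an MMDetection RT-DETR implementation.

Assumes a local MMDetection RT-DETR config and a ppdet checkpoint that has already
been converted from .pdparams to a torch-loadable .pth file (see the commented-out
helper block below).
"""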
from mmdet.registry import MODELS
from mmdet.utils import register_all_modules
from mmengine.config import Config
import torch
mmdet_config_path = 'configs/rtdetr/rtdetr_r18vd_dec3_8xb2-72e_coco.py'
model_weight_file = 'rtdetr_r18vd_dec3_6x_coco_mmdet.pth'
"""
# convert pdparams to torch tensors
import torch
import paddle
files = '''
rtdetr_hgnetv2_l_6x_coco.pdparams
rtdetr_hgnetv2_x_6x_coco.pdparams
rtdetr_r101vd_6x_coco.pdparams
rtdetr_r18vd_dec3_6x_coco.pdparams
rtdetr_r34vd_dec4_6x_coco.pdparams
rtdetr_r50vd_6x_coco.pdparams
rtdetr_r50vd_m_6x_coco.pdparams
'''
files = files.split()
for f in files:
state_dict = paddle.load(f)
new_dict = dict()
for k, v in state_dict.items():
new_dict[k] = torch.tensor(torch.tensor(v.numpy()))
torch.save(new_dict, f.replace('pdparams', 'pth'))
"""
ppdet_weight_path = 'rtdetr_r18vd_dec3_6x_coco.pth'

cfg = Config.fromfile(mmdet_config_path)
register_all_modules()
model = MODELS.build(cfg.model)

# key lists kept around for manually comparing the two naming schemes
model_keys = [k for k in model.state_dict().keys() if 'num_batches_tracked' not in k]
pretrained_weight = torch.load(ppdet_weight_path)
pre_keys = [k for k in pretrained_weight.keys()]
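# Rough shape of the mapping done by the loop below (patterns inferred from the
# replacements in this script; exact key names depend on the ppdet/mmdet versions used):
#   backbone.conv1.conv1_X.*          -> backbone.stem.{0,1,3,4,6,7}.*
#   backbone.resN*.branch2*/short.*   -> backbone.layer(N-1).<block>.conv*/bn*/downsample.*
#   neck.encoder...linear1/linear2    -> neck.encoder...ffn.layers.*
#   transformer.*                     -> decoder.* / bbox_head.* / neck.projector.* / memory_trans_*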
# convert pretrained weight to mmdet format
# NOTE: the order of the keys in the pretrained weight is different from the
# order of the keys in the model.state_dict()
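# backbone_conv_counts walks the stem nn.Sequential indices (conv/bn pairs at 0/1, 3/4, 6/7),
# prev_weight remembers whether the previous stem parameter was a weight tensor, and
# alphabet maps ppdet block suffixes ('a', 'b', ...) to numeric block/branch indices.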
new_weight = dict(state_dict={})
backbone_conv_counts = 0
norm_counts = 0
prev_weight = False
alphabet = 'abcdefghijklmnopqrstuvwxyz'
for k in sorted(pretrained_weight.keys()):
    origin_k = k
    v = pretrained_weight[k]
    # paddle BN buffer names -> torch BN buffer names
    if '_mean' in k:
        k = k.replace('_mean', 'running_mean')
    if '_variance' in k:
        k = k.replace('_variance', 'running_var')
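    # backbone stem: 'conv1.conv1_X.*' keys are remapped onto the deep-stem Sequential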
    if 'backbone' in k:
        prefix = 'backbone.'
        subname = k.split('backbone.')[-1]
        if subname.startswith('conv1.conv1_'):
            # a new conv weight following a BN weight means we moved on to the next stem conv
            if prev_weight and 'weight' in k:
                prev_weight = False
                if backbone_conv_counts in [2, 5]:
                    backbone_conv_counts += 1
            if subname.startswith('conv1.conv1_'):
                subname = subname.replace('conv1.conv1_', 'stem.', 1)
            if 'conv.weight' in subname:
                subname = subname.replace('conv.weight', 'weight', 1)
            if 'running_mean' in subname:
                norm_counts += 1
            if 'running_var' in subname:
                norm_counts += 1
            if 'norm' in subname:
                subname = subname.replace('norm.', '', 1)
            conv_id = str(int(subname.split('.')[1]) + 1 + backbone_conv_counts)
            # subname = '.'.join([subname.split('.')[0], conv_id, *subname.split('.')[2:]])
            subname = '.'.join([subname.split('.')[0], str(backbone_conv_counts), *subname.split('.')[2:]])
            if 'weight' in subname:
                prev_weight = True
                backbone_conv_counts += 1
            k = prefix + subname
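        # residual stages: 'resN<letter>.branch2<letter>' / 'resN<letter>.short' -> 'layer(N-1).<idx>.*'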
        if 'res' in k:
            k = k.split('.')
            k = '.'.join([k[0], *k[2:]])
            k = k.replace('res', 'layer')
            layer_id = k.split('.')[1]
            num = str(int(layer_id[-2]) - 1) + '.'
            letter = layer_id[-1]
            alphabet_id = alphabet.index(letter)
            layer_id = 'layer' + num + str(alphabet_id)
            k = '.'.join([k.split('.')[0], layer_id, *k.split('.')[2:]])
            if 'short' in k:
                k = k.replace('short', 'downsample')
                k = k.replace('conv', '1')
                k = k.replace('norm', '2')
                # e.g. 'backbone.layer1.0.downsample.1.weight'
                if 'layer1.0.downsample' not in k:
                    k = k.replace('downsample.1', 'downsample')
                if 'layer1.0.downsample' in k:
                    llayer_id = k.split('.')[-2]
                    num = str(int(llayer_id) - 1)
                    k = '.'.join([*k.split('.')[:-2], num, k.split('.')[-1]])
                    # k = k.replace('weight', '0.weight')
                    # if 'layer1.0.downsample.2' in k and 'running' in k:
                    #     k = k.replace('layer1.0.downsample.2', 'layer1.0.downsample.1')
            else:
                k = k.replace('branch2', '')
                sublayer_id, sublayer_type, name = k.split('.')[-3:]
                new_sublayer_id = sublayer_type + str(alphabet.index(sublayer_id) + 1)
                k = '.'.join([*k.split('.')[:3], new_sublayer_id, name])
        k = k.replace('norm', 'bn')
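    # neck (HybridEncoder): encoder layers, fusion bottlenecks and input projections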
    if 'neck' in k:
        if 'encoder' in k:
            k = k.replace('self_attn', 'self_attn.attn')
            k = k.replace('linear1', 'ffn.layers.0.0')
            k = k.replace('linear2', 'ffn.layers.1')
            k = k.replace('norm1', 'norms.0')
            k = k.replace('norm2', 'norms.1')
        else:
            if 'bottlenecks' in k:
                k = k.replace('conv1', 'branch_3x3')
                k = k.replace('conv2', 'branch_1x1')
            k = k.replace('bn', 'norm')
        if 'input_proj' in k:
            k = k.replace('0.0', '0.conv')
            k = k.replace('0.1', '0.bn')
            k = k.replace('r.0', '1.conv')
            k = k.replace('1.1', '1.bn')
            k = k.replace('0.weight', 'conv.weight')
            k = k.replace('2.1', '2.bn')
        if 'downsample' in k:
            pass
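    # transformer: decoder layers, denoising embedding, query pos head and encoder output projections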
    if 'transformer' in k:
        k = k.replace('transformer.input_proj', 'neck.projector.convs')
        if 'projector' in k:
            k = k.replace('norm', 'bn')
        k = k.replace('transformer.decoder', 'decoder')
        k = k.replace('self_attn', 'self_attn.attn')
        k = k.replace('transformer.denoising_class_embed.weight', 'dn_query_generator.label_embedding.weight')
        k = k.replace('transformer.query_pos_head', 'decoder.ref_point_head')
        k = k.replace('transformer.enc_output.0.weight', 'memory_trans_fc.weight')
        k = k.replace('transformer.enc_output.0.bias', 'memory_trans_fc.bias')
        k = k.replace('transformer.enc_output.1.weight', 'memory_trans_norm.weight')
        k = k.replace('transformer.enc_output.1.bias', 'memory_trans_norm.bias')
        k = k.replace('linear1.bias', 'ffn.layers.0.0.bias')
        k = k.replace('linear1.weight', 'ffn.layers.0.0.weight')
        k = k.replace('linear2.bias', 'ffn.layers.1.bias')
        k = k.replace('linear2.weight', 'ffn.layers.1.weight')
        if 'decoder' in k and 'norm' in k:
            norm_id = k.split('.')[-2]
            num_id = str(int(norm_id[-1]) - 1)
            k = '.'.join([*k.split('.')[:-2], 'norms.' + num_id, *k.split('.')[-1:]])
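        # detection heads: decoder/encoder bbox and score heads -> bbox_head branches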
        if 'head' in k:
            k = k.replace('transformer.dec_bbox_head', 'bbox_head.reg_branches')
            k = k.replace('transformer.dec_score_head', 'bbox_head.cls_branches')
            k = k.replace('transformer.enc_bbox_head.layers', 'bbox_head.reg_branches.3')
            k = k.replace('transformer.enc_score_head', 'bbox_head.cls_branches.3')
            if 'reg_branches' in k:
                k = k.replace('layers.', '')
                layer_id = str(int(k.split('.')[-2]) * 2)
                k = '.'.join([*k.split('.')[:-2], layer_id, *k.split('.')[-1:]])
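    # paddle stores nn.Linear weights as (in_features, out_features) while torch expects
    # (out_features, in_features), so transpose every 2-D weight except embeddings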
    if v.dim() == 2 and 'label_embedding' not in k:
        v = v.transpose(0, 1)
    if k not in new_weight['state_dict']:
        new_weight['state_dict'][k] = v
    else:
        print('duplicated!', k)
# strict=False: num_batches_tracked buffers are absent from the converted dict
load_result = model.load_state_dict(new_weight['state_dict'], strict=False)
print(load_result)
torch.save(new_weight, model_weight_file)
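# Quick sanity check of the converted checkpoint (a minimal sketch, assuming an
# mmdet 3.x installation and some local test image 'demo.jpg'):
#
#   from mmdet.apis import init_detector, inference_detector
#   detector = init_detector(mmdet_config_path, model_weight_file, device='cpu')
#   print(inference_detector(detector, 'demo.jpg'))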