Created
March 25, 2023 20:44
-
-
Save hugozanini/2b4dbe9ff17379bd6595c44c6ebba30c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "oD3rhSI7gOzB", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "10cbdd68-3453-4f1b-b8ad-fc7bd1df17fc" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content/yolov7\n" | |
] | |
} | |
], | |
"source": [ | |
"# import\n", | |
"%cd /content/yolov7\n", | |
"from copy import deepcopy\n", | |
"from models.yolo import Model\n", | |
"import torch\n", | |
"from utils.torch_utils import select_device, is_parallel\n", | |
"import yaml" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "9q4jN2-EgOzC", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "1802279e-3ce6-4bae-d2a1-58447b2c438c" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.9/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)\n", | |
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" | |
] | |
} | |
], | |
"source": [ | |
"device = select_device('0', batch_size=1)\n", | |
"# model trained by cfg/training/*.yaml\n", | |
"ckpt = torch.load('runs/train/yolov7_tiny_stockout/weights/best.pt', map_location=device)\n", | |
"# reparameterized model in cfg/deploy/*.yaml\n", | |
"model = Model('cfg/deploy/yolov7-tiny.yaml', ch=3, nc=1).to(device)\n", | |
"\n", | |
"with open('cfg/deploy/yolov7-tiny.yaml') as f:\n", | |
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n", | |
"anchors = len(yml['anchors'][0]) // 2\n", | |
"\n", | |
"# copy intersect weights\n", | |
"state_dict = ckpt['model'].float().state_dict()\n", | |
"exclude = []\n", | |
"intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n", | |
"model.load_state_dict(intersect_state_dict, strict=False)\n", | |
"model.names = ckpt['model'].names\n", | |
"model.nc = ckpt['model'].nc\n", | |
"\n", | |
"# reparametrized YOLOR\n", | |
"for i in range((model.nc+5)*anchors):\n", | |
" model.state_dict()['model.77.m.0.weight'].data[i, :, :, :] *= state_dict['model.77.im.0.implicit'].data[:, i, : :].squeeze()\n", | |
" model.state_dict()['model.77.m.1.weight'].data[i, :, :, :] *= state_dict['model.77.im.1.implicit'].data[:, i, : :].squeeze()\n", | |
" model.state_dict()['model.77.m.2.weight'].data[i, :, :, :] *= state_dict['model.77.im.2.implicit'].data[:, i, : :].squeeze()\n", | |
"model.state_dict()['model.77.m.0.bias'].data += state_dict['model.77.m.0.weight'].mul(state_dict['model.77.ia.0.implicit']).sum(1).squeeze()\n", | |
"model.state_dict()['model.77.m.1.bias'].data += state_dict['model.77.m.1.weight'].mul(state_dict['model.77.ia.1.implicit']).sum(1).squeeze()\n", | |
"model.state_dict()['model.77.m.2.bias'].data += state_dict['model.77.m.2.weight'].mul(state_dict['model.77.ia.2.implicit']).sum(1).squeeze()\n", | |
"model.state_dict()['model.77.m.0.bias'].data *= state_dict['model.77.im.0.implicit'].data.squeeze()\n", | |
"model.state_dict()['model.77.m.1.bias'].data *= state_dict['model.77.im.1.implicit'].data.squeeze()\n", | |
"model.state_dict()['model.77.m.2.bias'].data *= state_dict['model.77.im.2.implicit'].data.squeeze()\n", | |
"\n", | |
"# model to be saved\n", | |
"ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n", | |
" 'optimizer': None,\n", | |
" 'training_results': None,\n", | |
" 'epoch': -1}\n", | |
"\n", | |
"# save reparameterized model\n", | |
"torch.save(ckpt, 'runs/train/yolov7_tiny_stockout/weights/yolov7-tiny.pt')" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment