Skip to content

Instantly share code, notes, and snippets.

@hugozanini
Created March 25, 2023 20:44
Show Gist options
  • Save hugozanini/2b4dbe9ff17379bd6595c44c6ebba30c to your computer and use it in GitHub Desktop.
Save hugozanini/2b4dbe9ff17379bd6595c44c6ebba30c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oD3rhSI7gOzB",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "10cbdd68-3453-4f1b-b8ad-fc7bd1df17fc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/yolov7\n"
]
}
],
"source": [
"# import\n",
"%cd /content/yolov7\n",
"from copy import deepcopy\n",
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9q4jN2-EgOzC",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1802279e-3ce6-4bae-d2a1-58447b2c438c"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)\n",
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"
]
}
],
"source": [
"device = select_device('0', batch_size=1)\n",
"weights_dir = 'runs/train/yolov7_tiny_stockout/weights'\n",
"deploy_cfg = 'cfg/deploy/yolov7-tiny.yaml'  # reparameterized model in cfg/deploy/*.yaml\n",
"\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load(weights_dir + '/best.pt', map_location=device)\n",
"nc = ckpt['model'].nc\n",
"# Build the deploy model with the checkpoint's class count (was hard-coded to 1):\n",
"# on a mismatch the shape filter below silently skips the head convs, and the\n",
"# fusion step would then rescale randomly initialised weights or index past them.\n",
"model = Model(deploy_cfg, ch=3, nc=nc).to(device)\n",
"\n",
"with open(deploy_cfg) as f:\n",
"    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"num_layers = len(yml['anchors'])  # assumes one head conv m.j per anchor row (3 for yolov7-tiny)\n",
"anchors = len(yml['anchors'][0]) // 2  # anchors per layer; each row is a flat list of w,h pairs\n",
"\n",
"# copy every checkpoint tensor whose name and shape also exist in the deploy model\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"target_sd = model.state_dict()\n",
"intersect_state_dict = {k: v for k, v in state_dict.items() if k in target_sd and v.shape == target_sd[k].shape}\n",
"model.load_state_dict(intersect_state_dict, strict=False)\n",
"model.names = ckpt['model'].names\n",
"model.nc = nc\n",
"\n",
"# Reparameterize YOLOR implicit knowledge into the plain 1x1 head convs.\n",
"# The fusion treats each head as im * (conv(x + ia)), which folds into\n",
"#   W' = im * W   and   b' = im * (b + sum_in(W * ia))\n",
"# where W is the trained conv weight taken from the checkpoint (unscaled).\n",
"detect = 'model.77'  # key prefix of the yolov7-tiny detection head\n",
"n_out = (model.nc + 5) * anchors  # head conv output channels: (nc + 5) per anchor\n",
"target_sd = model.state_dict()  # detached tensors sharing storage with the parameters\n",
"with torch.no_grad():\n",
"    for j in range(num_layers):\n",
"        w = target_sd[f'{detect}.m.{j}.weight']\n",
"        b = target_sd[f'{detect}.m.{j}.bias']\n",
"        ia = state_dict[f'{detect}.ia.{j}.implicit']\n",
"        im = state_dict[f'{detect}.im.{j}.implicit']\n",
"        if w.shape[0] != n_out:\n",
"            raise ValueError(f'{detect}.m.{j} has {w.shape[0]} output channels, expected {n_out}')\n",
"        # bias term uses the original checkpoint weights, not the rescaled ones\n",
"        b += state_dict[f'{detect}.m.{j}.weight'].mul(ia).sum(1).squeeze()\n",
"        b *= im.squeeze()\n",
"        w *= im.reshape(-1, 1, 1, 1)  # per-output-channel scale, all rows at once\n",
"\n",
"# model to be saved (fp16 copy; unwrap .module when is_parallel(model))\n",
"ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
"        'optimizer': None,\n",
"        'training_results': None,\n",
"        'epoch': -1}\n",
"\n",
"# save reparameterized model next to the training weights\n",
"torch.save(ckpt, weights_dir + '/yolov7-tiny.pt')"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment