Torch7 to PyTorch converter
@nulledge · Created October 25, 2018 06:09
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import skimage\n",
"import skimage.io\n",
"import imageio\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from dotmap import DotMap\n",
"from torch.utils.serialization import load_lua\n",
"from operator import xor"
]
},
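{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: `torch.utils.serialization.load_lua` only exists in PyTorch 0.4.x and earlier; it was removed in PyTorch 1.0. The next cell is an optional, hedged fallback sketch that assumes the third-party `torchfile` package is installed. `torchfile.load` returns plain Python objects rather than `torch.legacy.nn` modules, so the attribute access in `Node._get_param` below may need adapting if this fallback is actually used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional fallback sketch (assumption: the third-party torchfile package is available).\n",
"# torchfile.load reads a .t7 file into plain Python objects; unlike load_lua it does not\n",
"# rebuild torch.legacy.nn modules, so downstream attribute access may need adjustment.\n",
"try:\n",
"    from torch.utils.serialization import load_lua  # PyTorch <= 0.4 only\n",
"except ImportError:\n",
"    import torchfile\n",
"\n",
"    def load_lua(path):\n",
"        return torchfile.load(path)"
]
},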
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Layer():\n",
" Identity = 0b0000000001\n",
" Convolution = 0b0000000010\n",
" Batch_Norm = 0b0000000100\n",
" ReLU = 0b0000001000\n",
" Sequential = 0b0000010000\n",
" Max_Pool = 0b0000100000\n",
" Add = 0b0001000000\n",
" Nearest_Upsample = 0b0010000000\n",
" Concat = 0b0100000000\n",
" Join = 0b1000000000\n",
" \n",
" @staticmethod\n",
" def to_string(layer):\n",
" return 'Identity' if layer & Layer.Identity else \\\n",
" 'Convolution' if layer & Layer.Convolution else \\\n",
" 'Batch_Norm' if layer & Layer.Batch_Norm else \\\n",
" 'ReLU' if layer & Layer.ReLU else \\\n",
" 'Sequential' if layer & Layer.Sequential else \\\n",
" 'Max_Pool' if layer & Layer.Max_Pool else \\\n",
" 'Add' if layer & Layer.Add else \\\n",
" 'Nearest_Upsample' if layer & Layer.Nearest_Upsample else \\\n",
" 'Concat' if layer & Layer.Concat else \\\n",
" 'Join' if layer & Layer.Join else \\\n",
" None\n",
" \n",
" @staticmethod\n",
" def from_name(name):\n",
" return Layer.Identity if name.startswith('nn.Identity') else \\\n",
" Layer.Convolution if name.startswith('nn.SpatialConvolution') else \\\n",
" Layer.Batch_Norm if name.startswith('nn.SpatialBatchNormalization') else \\\n",
" Layer.ReLU if name.startswith('nn.ReLU') else \\\n",
" Layer.Sequential if name.startswith('nn.Sequential') else \\\n",
" Layer.Max_Pool if name.startswith('nn.SpatialMaxPooling') else \\\n",
" Layer.Add if name.startswith('nn.CAddTable') else \\\n",
" Layer.Nearest_Upsample if name.startswith('nn.SpatialUpSamplingNearest') else \\\n",
" Layer.Concat if name.startswith('torch.legacy.nn.ConcatTable.ConcatTable') else \\\n",
" Layer.Join if name.startswith('nn.JoinTable') else \\\n",
" None"
]
},
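{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, optional sanity check of the name-to-flag mapping. The layer string below is a made-up example of what `str(module)` looks like for a Torch7 convolution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: the string mimics Torch7's printed module name.\n",
"assert Layer.from_name('nn.SpatialConvolution(3 -> 64, 7x7, 2, 2, 3, 3)') == Layer.Convolution\n",
"assert Layer.to_string(Layer.Batch_Norm) == 'Batch_Norm'\n",
"assert Layer.from_name('nn.SomethingUnknown') is None"
]
},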
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
" def __init__(self, forwardnode, module):\n",
" forwardnode = forwardnode.split('\\n')[0]\n",
" self.id, children = forwardnode.split(';')\n",
" self.children = [word for word in children.split(' ') if word]\n",
" self.data = module\n",
" self.op = Layer.from_name(str(self.data))\n",
" \n",
" assert self.op is not None\n",
" assert self.id.isdigit()\n",
" assert all([child.isdigit() for child in self.children])\n",
" \n",
" def __str__(self):\n",
" return '{node}; {operation}'.format(node=self.id, operation=Layer.to_string(self.op))\n",
" \n",
" @staticmethod\n",
" def _get_param(module):\n",
" op = Layer.from_name(str(module))\n",
" if op & Layer.Convolution:\n",
" param = module.weight, module.bias\n",
" return op, param\n",
" \n",
" elif op & Layer.Batch_Norm:\n",
" param = module.running_mean, module.running_var, module.weight, module.bias, module.momentum\n",
" return op, param\n",
" \n",
" elif op & (Layer.Sequential | Layer.Concat | Layer.Join):\n",
" sub_modules = [Node._get_param(sub_module) for sub_module in module.modules]\n",
" return op, sub_modules\n",
" \n",
" else:\n",
" param = None\n",
" return op, param\n",
" \n",
" \n",
" def get_param(self):\n",
" return Node._get_param(self.data)\n",
" \n",
" \n",
" @staticmethod\n",
" def _copy_to_convolution(source, target):\n",
" op, param = source\n",
" weight, bias = param\n",
"\n",
" assert target.weight.shape == weight.shape\n",
" assert target.bias.shape == bias.shape\n",
"\n",
" target.weight.data = weight\n",
" target.bias.data = bias\n",
" \n",
" \n",
" @staticmethod\n",
" def _copy_to_batch_norm(source, target):\n",
" op, param = source\n",
" running_mean, running_var, weight, bias, momentum = param\n",
"\n",
" assert target.running_mean.shape == running_mean.shape\n",
" assert target.running_var.shape == running_var.shape\n",
" assert target.weight.shape == weight.shape\n",
" assert target.bias.shape == bias.shape\n",
" assert isinstance(target.momentum, float) and isinstance(momentum, float)\n",
"\n",
" target.running_mean = running_mean\n",
" target.running_var = running_var\n",
" target.weight.data = weight\n",
" target.bias.data = bias\n",
" target.momentum = momentum\n",
" \n",
" \n",
" @staticmethod\n",
" def _copy_to_residual(source, target):\n",
" op, sub_modules = source\n",
" op, sub_modules = sub_modules[0]\n",
" op, sub_modules = sub_modules[0]\n",
"\n",
" assert op & Layer.Sequential\n",
"\n",
" for torch_module, pytorch_module in zip(sub_modules, target.resSeq):\n",
" op, _ = torch_module\n",
"\n",
" if op & Layer.Batch_Norm:\n",
" Node._copy_to_batch_norm(torch_module, pytorch_module)\n",
"\n",
" elif op & Layer.ReLU:\n",
" continue\n",
"\n",
" elif op & Layer.Convolution:\n",
" Node._copy_to_convolution(torch_module, pytorch_module)\n",
"\n",
" else:\n",
" raise NotImplementedError()\n",
"\n",
" op, sub_modules = source\n",
" op, sub_modules = sub_modules[0]\n",
" op, param = sub_modules[1]\n",
"\n",
" if op & Layer.Identity:\n",
" pass\n",
"\n",
" elif op & Layer.Sequential:\n",
" sub_modules = param\n",
" op, param = sub_modules[0]\n",
"\n",
" assert op & Layer.Convolution and len(sub_modules) == 1\n",
" \n",
" Node._copy_to_convolution(sub_modules[0], target.conv_skip)\n",
" \n",
" @staticmethod\n",
" def _copy_to(source, target):\n",
" op, param = source\n",
" if op & Layer.Convolution:\n",
" Node._copy_to_convolution(source, target)\n",
" \n",
" elif op & Layer.Batch_Norm:\n",
" Node._copy_to_batch_norm(source, target)\n",
" \n",
" elif op & Layer.Sequential:\n",
" if Node._is_residual(source):\n",
" Node._copy_to_residual(source, target)\n",
" else:\n",
" sub_modules = param\n",
" for torch_module, pytorch_module in zip(sub_modules, target):\n",
" Node._copy_to(torch_module, pytorch_module)\n",
" raise NotImplementedError()\n",
" \n",
" elif op & (Layer.Identity | Layer.ReLU | Layer.Add | Layer.Max_Pool | Layer.Nearest_Upsample):\n",
" pass\n",
" \n",
" else:\n",
" raise NotImplementedError()\n",
" \n",
" def copy_to(self, target):\n",
" Node._copy_to(self.get_param(), target)\n",
" \n",
" @staticmethod\n",
" def _is_residual(source):\n",
" return True"
]
},
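{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line of `c2f/forwardnodes.txt` is expected to look like `12;13 14`: a node id, a semicolon, then space-separated child ids, which is exactly what `Node.__init__` parses. The snippet below replays that parsing on a made-up line (no Torch7 module needed)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical example line; replicates the parsing done in Node.__init__.\n",
"line = '12;13 14\\n'\n",
"node_id, children = line.split('\\n')[0].split(';')\n",
"print(node_id, [child for child in children.split(' ') if child])  # -> 12 ['13', '14']"
]
},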
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Graph:\n",
" def __init__(self):\n",
" forwardnodes, modules = self.load_data()\n",
"\n",
" self.node = list()\n",
" for forwardnode, module in zip(forwardnodes, modules):\n",
" self.node.append(Node(forwardnode, module))\n",
" \n",
" def load_data(self):\n",
" modules = load_lua('c2f/modules.t7')\n",
" with open('c2f/forwardnodes.txt', 'r') as fd:\n",
" lines = fd.readlines()\n",
" forwardnodes = lines[1:] # The 1st forwardnode is dummy, a input distributor\n",
" return forwardnodes, modules\n",
" \n",
" def find_by_id(self, key):\n",
" for _, node in enumerate(self.node):\n",
" if node.id == key:\n",
" return node\n",
" raise LookupError()\n",
" \n",
" def copy_to_hg(self, first_res_in_torch7, hg):\n",
" # starts from 9\n",
" res_in_torch7 = [\n",
" 0, # 64x64 skip\n",
" 1, # 64x64 skip\n",
" 2, # 64x64 skip\n",
" \n",
" 4, # 32x32 res\n",
" 5, # 32x32 res\n",
" 6, # 32x32 res\n",
" 7, # 32x32 skip\n",
" 8, # 32x32 skip\n",
" 9, # 32x32 skip\n",
" \n",
" 11, # 16x16 res\n",
" 12, # 16x16 res\n",
" 13, # 16x16 res\n",
" 14, # 16x16 skip\n",
" 15, # 16x16 skip\n",
" 16, # 16x16 skip\n",
" \n",
" 18, # 8x8 res\n",
" 19, # 8x8 res\n",
" 20, # 8x8 res\n",
" 21, # 8x8 skip\n",
" 22, # 8x8 skip\n",
" 23, # 8x8 skip\n",
" \n",
" 25, # 4x4 res\n",
" 26, # 4x4 res\n",
" 27, # 4x4 res\n",
" 28, # 4x4 res\n",
" \n",
" 29, # 4x4 res\n",
" 32, # 8x8 res\n",
" 35, # 16x16 res\n",
" 38, # 32x32 res\n",
" ]\n",
" res_in_pytorch = [\n",
" hg.up1,\n",
" hg.up2,\n",
" hg.up4,\n",
" \n",
" hg.low1,\n",
" hg.low2,\n",
" hg.low5,\n",
" hg.low6.up1,\n",
" hg.low6.up2,\n",
" hg.low6.up4,\n",
" \n",
" hg.low6.low1,\n",
" hg.low6.low2,\n",
" hg.low6.low5,\n",
" hg.low6.low6.up1,\n",
" hg.low6.low6.up2,\n",
" hg.low6.low6.up4,\n",
" \n",
" hg.low6.low6.low1,\n",
" hg.low6.low6.low2,\n",
" hg.low6.low6.low5,\n",
" hg.low6.low6.low6.up1,\n",
" hg.low6.low6.low6.up2,\n",
" hg.low6.low6.low6.up4,\n",
" \n",
" hg.low6.low6.low6.low1,\n",
" hg.low6.low6.low6.low2,\n",
" hg.low6.low6.low6.low5,\n",
" hg.low6.low6.low6.low6,\n",
" \n",
" hg.low6.low6.low6.low7,\n",
" hg.low6.low6.low7,\n",
" hg.low6.low7,\n",
" hg.low7,\n",
" ]\n",
" for torch7_idx, pytorch_module in zip(res_in_torch7, res_in_pytorch):\n",
" torch7_module = self.node[first_res_in_torch7 + torch7_idx]\n",
" torch7_module.copy_to(pytorch_module)\n",
" \n",
" def copy_to_intermediate(self, first_conv_in_torch7, l1, l2, out1, out1_, cat1_):\n",
" self.node[first_conv_in_torch7 + 0].copy_to(l1[0]) # Conv\n",
" self.node[first_conv_in_torch7 + 1].copy_to(l1[1]) # Batch-norm\n",
" self.node[first_conv_in_torch7 + 2].copy_to(l1[2]) # ReLU, ll in Newell's\n",
" self.node[first_conv_in_torch7 + 3].copy_to(l2[0]) # Conv\n",
" self.node[first_conv_in_torch7 + 4].copy_to(l2[1]) # Batch-norm\n",
" self.node[first_conv_in_torch7 + 5].copy_to(l2[2]) # ReLU, ll in Newell's\n",
" \n",
" self.node[first_conv_in_torch7 + 6].copy_to(out1) # Conv, tmpOut in Newell's\n",
" \n",
" if out1_ == None and cat1_ == None:\n",
" return\n",
" \n",
" self.node[first_conv_in_torch7 + 9].copy_to(out1_) # Conv, ll_ in Newell's\n",
" self.node[first_conv_in_torch7 + 8].copy_to(cat1_) # Conv, tmpOut_ in Newell's"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph = Graph()"
]
},
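{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, print a few parsed nodes to eyeball the Torch7 graph before copying any weights (purely diagnostic, relying on `Node.__str__`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Diagnostic only: show the id and operation of the first few Torch7 graph nodes.\n",
"for node in graph.node[:5]:\n",
"    print(node)"
]
},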
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Residual(nn.Module):\n",
" def __init__(self, in_channels, out_channels):\n",
" super(Residual, self).__init__()\n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
" self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)\n",
"\n",
" self.resSeq = nn.Sequential(\n",
" nn.BatchNorm2d(in_channels),\n",
" nn.ReLU(),\n",
" nn.Conv2d(in_channels, out_channels // 2, kernel_size=1),\n",
" nn.BatchNorm2d(out_channels // 2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=3, stride=1, padding=1),\n",
" nn.BatchNorm2d(out_channels // 2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(out_channels // 2, out_channels, kernel_size=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" if self.in_channels != self.out_channels:\n",
" skip = self.conv_skip(x)\n",
" else:\n",
" skip = x\n",
"\n",
" return skip + self.resSeq(x)"
]
},
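{
"cell_type": "markdown",
"metadata": {},
"source": [
"An optional shape check of the `Residual` block with arbitrary sizes, confirming that the 1x1 skip convolution is only needed when the channel count changes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Arbitrary sizes, illustrative only.\n",
"x = torch.randn(1, 64, 32, 32)\n",
"assert Residual(64, 128)(x).shape == (1, 128, 32, 32)  # channel change -> conv_skip path\n",
"assert Residual(64, 64)(x).shape == (1, 64, 32, 32)    # same channels -> identity skip"
]
},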
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Hourglass(nn.Module):\n",
" def __init__(self, n, numIn, numOut):\n",
" super(Hourglass, self).__init__()\n",
" \n",
" self.up1 = Residual(numIn, 256)\n",
" self.up2 = Residual(256, 256)\n",
" self.up4 = Residual(256, numOut)\n",
" \n",
" self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
" \n",
" self.low1 = Residual(numIn, 256)\n",
" self.low2 = Residual(256, 256)\n",
" self.low5 = Residual(256, 256)\n",
" \n",
" if n > 1:\n",
" self.low6 = Hourglass(\n",
" \n",
" n-1, 256, numOut)\n",
" else:\n",
" self.low6 = Residual(256, numOut)\n",
" \n",
" self.low7 = Residual(numOut, numOut)\n",
" self.up5 = nn.UpsamplingNearest2d(scale_factor=2)\n",
" \n",
" def forward(self, inp):\n",
" up1 = self.up1(inp)\n",
" up2 = self.up2(up1)\n",
" up4 = self.up4(up2)\n",
" \n",
" pool = self.max_pool(inp)\n",
" low1 = self.low1(pool)\n",
" low2 = self.low2(low1)\n",
" low5 = self.low5(low2)\n",
" \n",
" low6 = self.low6(low5)\n",
" low7 = self.low7(low6)\n",
" up5 = self.up5(low7)\n",
" \n",
" return up4 + up5"
]
},
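{
"cell_type": "markdown",
"metadata": {},
"source": [
"An optional shape check of the recursive `Hourglass`. With `n=4` the input is pooled four times, so its spatial size must be divisible by 16; the sizes here are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Arbitrary sizes, illustrative only; 64x64 survives four rounds of 2x2 pooling.\n",
"x = torch.randn(1, 256, 64, 64)\n",
"assert Hourglass(4, 256, 512)(x).shape == (1, 512, 64, 64)"
]
},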
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MainModel(nn.Module):\n",
" def __init__(self, in_channels=3):\n",
" super(MainModel, self).__init__()\n",
"\n",
" \n",
" self.cnv1_ = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3)\n",
" self.cnv1 = nn.Sequential(\n",
" nn.BatchNorm2d(num_features=64),\n",
" nn.ReLU(),\n",
" )\n",
" self.r1 = Residual(in_channels=64, out_channels=128)\n",
" self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.r4 = Residual(128, 128)\n",
" self.r5 = Residual(128, 128)\n",
" self.r6 = Residual(128, 256)\n",
" \n",
" self.outputDim = [1* 17, 2* 17, 4* 17, 64* 17]\n",
" \n",
" self.hg1 = Hourglass(4, 256, 512)\n",
" self.l1 = self.lin(512, 512)\n",
" self.l2 = self.lin(512, 256)\n",
" self.out1 = nn.Conv2d(256, self.outputDim[0], kernel_size=1, stride=1, padding=0)\n",
" self.out1_ = nn.Conv2d(self.outputDim[0], 256+128, kernel_size=1, stride=1, padding=0)\n",
" self.cat1_ = nn.Conv2d(256+128, 256+128, kernel_size=1, stride=1, padding=0)\n",
" \n",
" self.hg2 = Hourglass(4, 256+128, 512)\n",
" self.l3 = self.lin(512, 512)\n",
" self.l4 = self.lin(512, 256)\n",
" self.out2 = nn.Conv2d(256, self.outputDim[1], kernel_size=1, stride=1, padding=0)\n",
" self.out2_ = nn.Conv2d(self.outputDim[1], 256+256, kernel_size=1, stride=1, padding=0)\n",
" self.cat2_ = nn.Conv2d(256+256, 256+256, kernel_size=1, stride=1, padding=0)\n",
" \n",
" self.hg3 = Hourglass(4, 256+256, 512)\n",
" self.l5 = self.lin(512, 512)\n",
" self.l6 = self.lin(512, 256)\n",
" self.out3 = nn.Conv2d(256, self.outputDim[2], kernel_size=1, stride=1, padding=0)\n",
" self.out3_ = nn.Conv2d(self.outputDim[2], 256+256, kernel_size=1, stride=1, padding=0)\n",
" self.cat3_ = nn.Conv2d(256+256, 256+256, kernel_size=1, stride=1, padding=0)\n",
" \n",
" self.hg4 = Hourglass(4, 256+256, 512)\n",
" self.l7 = self.lin(512, 512)\n",
" self.l8 = self.lin(512, 512)\n",
" self.out4 = nn.Conv2d(512, self.outputDim[3], kernel_size=1, stride=1, padding=0)\n",
" \n",
" \n",
" def forward(self, inp):\n",
" cnv1_ = self.cnv1_(inp)\n",
" cnv1 = self.cnv1(cnv1_)\n",
" r1 = self.r1(cnv1)\n",
" pool = self.pool(r1)\n",
" r4 = self.r4(pool)\n",
" r5 = self.r5(r4) \n",
" r6 = self.r6(r5)\n",
" \n",
" hg1 = self.hg1(r6)\n",
" l1 = self.l1(hg1)\n",
" l2 = self.l2(l1)\n",
" out1 = self.out1(l2)\n",
" out1_ = self.out1_(out1)\n",
" cat1 = torch.cat([l2, pool], 1)\n",
" cat1_ = self.cat1_(cat1)\n",
" int1 = cat1_ + out1_\n",
" \n",
" hg2 = self.hg2(int1)\n",
" l3 = self.l3(hg2)\n",
" l4 = self.l4(l3)\n",
" out2 = self.out2(l4)\n",
" out2_ = self.out2_(out2)\n",
" cat2 = torch.cat([l4, l2], 1)\n",
" cat2_ = self.cat2_(cat2)\n",
" int2 = cat2_ + out2_\n",
" \n",
" hg3 = self.hg3(int2)\n",
" l5 = self.l5(hg3)\n",
" l6 = self.l6(l5)\n",
" out3 = self.out3(l6)\n",
" out3_ = self.out3_(out3)\n",
" cat3 = torch.cat([l6, l4], 1)\n",
" cat3_ = self.cat3_(cat3)\n",
" int3 = cat3_ + out3_\n",
" \n",
" hg4 = self.hg4(int3)\n",
" l7 = self.l7(hg4)\n",
" l8 = self.l8(l7)\n",
" out4 = self.out4(l8)\n",
" \n",
" return out4\n",
"\n",
" \n",
" def lin(self, in_channels, out_channels):\n",
" return nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),\n",
" nn.BatchNorm2d(num_features=out_channels),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c2f = MainModel()"
]
},
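{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point `c2f` still holds random weights; the cells below copy the Torch7 parameters into it. A quick parameter count can help confirm the architecture was assembled as intended."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough size check of the freshly built (not yet converted) PyTorch model.\n",
"n_params = sum(p.numel() for p in c2f.parameters())\n",
"print('MainModel parameters: {:,}'.format(n_params))"
]
},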
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"feature_extractor = [\n",
" c2f.cnv1_,\n",
" c2f.cnv1[0],\n",
" c2f.cnv1[1],\n",
" c2f.r1,\n",
" c2f.pool,\n",
" c2f.r4,\n",
" c2f.r5,\n",
" c2f.r6,\n",
"]\n",
"for torch7_module, pytorch_module in zip(graph.node[1:8+1], feature_extractor):\n",
" torch7_module.copy_to(pytorch_module)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for src, dst in zip([9, 9+52, 9+52*2, 9+52*3], [c2f.hg1, c2f.hg2, c2f.hg3, c2f.hg4, ]):\n",
" graph.copy_to_hg(first_res_in_torch7=src, hg=dst)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph.copy_to_intermediate(\n",
" first_conv_in_torch7 = 50,\n",
" l1 = c2f.l1,\n",
" l2 = c2f.l2,\n",
" out1 = c2f.out1,\n",
" out1_ = c2f.out1_,\n",
" cat1_ = c2f.cat1_,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph.copy_to_intermediate(\n",
" first_conv_in_torch7 = 50 + 52,\n",
" l1 = c2f.l3,\n",
" l2 = c2f.l4,\n",
" out1 = c2f.out2,\n",
" out1_ = c2f.out2_,\n",
" cat1_ = c2f.cat2_,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph.copy_to_intermediate(\n",
" first_conv_in_torch7 = 50 + 52*2,\n",
" l1 = c2f.l5,\n",
" l2 = c2f.l6,\n",
" out1 = c2f.out3,\n",
" out1_ = c2f.out3_,\n",
" cat1_ = c2f.cat3_,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph.copy_to_intermediate(\n",
" first_conv_in_torch7 = 50 + 52*3,\n",
" l1 = c2f.l7,\n",
" l2 = c2f.l8,\n",
" out1 = c2f.out4,\n",
" out1_ = None,\n",
" cat1_ = None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torchvision import transforms\n",
"\n",
"Color = torch.FloatTensor(\n",
" [[0, 0, 0.5],\n",
" [0, 0, 1],\n",
" [0, 1, 0],\n",
" [1, 1, 0],\n",
" [1, 0, 0]]\n",
")\n",
"\n",
"\n",
"def merge_to_color_heatmap(batch_heatmaps):\n",
" batch, joints, depth, height, width = batch_heatmaps.size()\n",
"\n",
" batch_heatmaps_flat = batch_heatmaps.view(batch, joints, depth, -1)\n",
" max_depth_idx = batch_heatmaps_flat.max(-1)[0].max(-1)[1]\n",
"\n",
" test = list()\n",
" for b_idx in range(batch):\n",
" test2 = list()\n",
" for j_idx in range(joints):\n",
" test2.append(batch_heatmaps[b_idx, j_idx, max_depth_idx[b_idx, j_idx], :, :].view(1, 1, height, width))\n",
"\n",
" test.append(torch.cat(test2, dim=1))\n",
"\n",
" batch_heatmaps = torch.cat(test, dim=0)\n",
" # batch_heatmaps = torch.cat(\n",
" # [torch.cat(\n",
" # [batch_heatmaps[b_idx, j_idx, max_depth_idx[b_idx, j_idx], :, :] for j_idx in range(joints)], dim=2)\n",
" # for b_idx in range(batch)], dim=0)\n",
"\n",
" heatmaps = batch_heatmaps.clamp(0, 1.).view(-1)\n",
"\n",
" frac = torch.div(heatmaps, 0.25)\n",
" lower_indices, upper_indices = torch.floor(frac).long(), torch.ceil(frac).long()\n",
"\n",
" t = frac - torch.floor(frac)\n",
" t = t.view(-1, 1)\n",
"\n",
" k = Color.index_select(0, lower_indices)\n",
" k_1 = Color.index_select(0, upper_indices)\n",
"\n",
" color_heatmap = (1.0 - t) * k + t * k_1\n",
" color_heatmap = color_heatmap.view(batch, joints, height, width, 3)\n",
" color_heatmap = color_heatmap.permute(0, 4, 2, 3, 1) # B3HWC\n",
" color_heatmap, _ = torch.max(color_heatmap, 4) # B3HW\n",
"\n",
" return color_heatmap\n",
"\n",
"\n",
"T = transforms.Compose([\n",
" transforms.ToPILImage(),\n",
" transforms.Resize(256),\n",
" transforms.ToTensor()\n",
"])\n",
"\n",
"\n",
"def get_merged_image(heatmaps, images):\n",
" heatmaps = merge_to_color_heatmap(heatmaps)\n",
" # heatmaps = heatmaps.permute(0, 2, 3, 1) # NHWC\n",
"\n",
" resized_heatmaps = list()\n",
" for idx, ht in enumerate(heatmaps):\n",
" color_ht = T(ht)\n",
" # color_ht = skimage.transform.resize(ht.numpy(), (256, 256), mode='constant')\n",
" resized_heatmaps.append(color_ht)\n",
"\n",
" resized_heatmaps = np.stack(resized_heatmaps, axis=0)\n",
"\n",
" # images = images.transpose(0, 2, 3, 1) * 0.6\n",
" images = images * 0.6\n",
" overlayed_image = np.clip(images + resized_heatmaps * 0.4, 0, 1.)\n",
"\n",
" # overlayed_image = overlayed_image.transpose(0, 3, 1, 2)\n",
"\n",
" return overlayed_image\n",
" # return viz.images(tensor=overlayed_image, nrow=3, win=window)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rgb = np.asarray(skimage.img_as_float(skimage.io.imread('rgb.jpg')))\n",
"rgb = np.expand_dims(rgb.transpose(2, 0, 1), axis=0)\n",
"rgb = torch.Tensor(rgb)\n",
"htmaps = c2f(rgb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_cpu = htmaps.view(1, 17, 64, 64, 64)\n",
"\n",
"overlay = get_merged_image(output_cpu, rgb.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"imageio.imwrite('c2f.jpg', np.squeeze(overlay).transpose(1, 2, 0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.save(\n",
" c2f.state_dict(),\n",
" 'torch7_c2f.save'\n",
")"
]
}
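{
"cell_type": "markdown",
"metadata": {},
"source": [
"The saved `state_dict` can later be loaded back into a freshly constructed `MainModel`; a minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal reload sketch for the converted weights saved above.\n",
"restored = MainModel()\n",
"restored.load_state_dict(torch.load('torch7_c2f.save'))\n",
"restored.eval()"
]
}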
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}