Skip to content

Instantly share code, notes, and snippets.

@Lyken17
Created October 2, 2021 12:23
Show Gist options
  • Save Lyken17/60dec0b4c64db6f85918179eca35f432 to your computer and use it in GitHub Desktop.
Save Lyken17/60dec0b4c64db6f85918179eca35f432 to your computer and use it in GitHub Desktop.
Use torch.fx to count FLOPs
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 266,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from torch.fx import symbolic_trace"
]
},
{
"cell_type": "code",
"execution_count": 277,
"metadata": {},
"outputs": [],
"source": [
"# Simple module for demonstration\n",
"\n",
"class MyOP(nn.Module):\n",
" def forward(self, input):\n",
" return input - 1\n",
"\n",
"class MyModule(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.param = torch.nn.Parameter(torch.rand(3, 4))\n",
" self.linear = torch.nn.Linear(4, 5)\n",
" self.linear2 = torch.nn.Linear(5, 5)\n",
" self.myop = MyOP()\n",
"\n",
" def forward(self, x):\n",
" out = self.linear(x + self.param)\n",
" out = out ** 2\n",
" out = self.myop(out)\n",
" return self.linear2(x).clamp(min=0.0, max=1.0)\n",
"\n",
"\n",
"module = MyModule()\n",
"# Symbolic tracing frontend - captures the semantics of the module\n",
"symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)"
]
},
{
"cell_type": "code",
"execution_count": 278,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"graph():\n",
" %x : [#users=2] = placeholder[target=x]\n",
" %param : [#users=1] = get_attr[target=param]\n",
" %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})\n",
" %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})\n",
" %pow_1 : [#users=1] = call_function[target=operator.pow](args = (%linear, 2), kwargs = {})\n",
" %sub : [#users=0] = call_function[target=operator.sub](args = (%pow_1, 1), kwargs = {})\n",
" %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})\n",
" %clamp : [#users=1] = call_method[target=clamp](args = (%linear2,), kwargs = {min: 0.0, max: 1.0})\n",
" return clamp\n"
]
}
],
"source": [
"print(symbolic_traced.graph)"
]
},
{
"cell_type": "code",
"execution_count": 279,
"metadata": {},
"outputs": [],
"source": [
"g = symbolic_traced.graph"
]
},
{
"cell_type": "code",
"execution_count": 280,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"opcode name target args kwargs\n",
"------------- ------- ----------------------- ----------- ------------------------\n",
"placeholder x x () {}\n",
"get_attr param param () {}\n",
"call_function add <built-in function add> (x, param) {}\n",
"call_module linear linear (add,) {}\n",
"call_function pow_1 <built-in function pow> (linear, 2) {}\n",
"call_function sub <built-in function sub> (pow_1, 1) {}\n",
"call_module linear2 linear2 (x,) {}\n",
"call_method clamp clamp (linear2,) {'min': 0.0, 'max': 1.0}\n",
"output output output (clamp,) {}\n"
]
}
],
"source": [
"g.print_tabular()"
]
},
{
"cell_type": "code",
"execution_count": 281,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call_module linear (add,)\n",
"call_module linear2 (x,)\n",
"call_method clamp (linear2,)\n"
]
}
],
"source": [
"for node in g.nodes:\n",
" if node.op in [\"call_module\", \"call_method\"]:\n",
" print(node.op, node.target, node.args)"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"def forward(self, x):\n",
" param = self.param\n",
" add = x + param; param = None\n",
" linear = self.linear(add); add = None\n",
" linear2 = self.linear2(x); x = None\n",
" clamp = linear2.clamp(min = 0.0, max = 1.0); linear2 = None\n",
" return clamp\n",
" \n"
]
}
],
"source": [
"print(symbolic_traced.code)"
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl"
]
},
"execution_count": 200,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(symbolic_traced)"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [],
"source": [
"from torch.fx.passes.shape_prop import ShapeProp"
]
},
{
"cell_type": "code",
"execution_count": 282,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TwoLayerNet(\n",
" (linear1): Linear(in_features=1000, out_features=100, bias=True)\n",
" (linear2): Linear(in_features=100, out_features=10, bias=True)\n",
")\n"
]
}
],
"source": [
"class TwoLayerNet(torch.nn.Module):\n",
" def __init__(self, D_in, H, D_out):\n",
" super(TwoLayerNet, self).__init__()\n",
" self.linear1 = torch.nn.Linear(D_in, H)\n",
" self.linear2 = torch.nn.Linear(H, D_out)\n",
" def forward(self, x):\n",
" h_relu = self.linear1(x).clamp(min=0)\n",
" y_pred = self.linear2(h_relu).clamp(min=0)\n",
" \n",
" return y_pred * y_pred\n",
"\n",
"N, D_in, H, D_out = 64, 1000, 100, 10\n",
"x = torch.randn(N, D_in)\n",
"y = torch.randn(N, D_out)\n",
"net = TwoLayerNet(D_in, H, D_out)\n",
"gm = torch.fx.symbolic_trace(net)\n",
"sample_input = torch.randn(50, D_in)\n",
"ShapeProp(gm).propagate(sample_input);\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 283,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x,\tplaceholder,\t()\n",
"input_shape:\t\n",
"output_shape:\ttorch.Size([50, 1000])\n",
"========================================\n",
"linear1,\tcall_module,\t(x,)\n",
"weight_shape: torch.Size([100, 1000])\n",
"input_shape:\ttorch.Size([50, 1000])\t\n",
"output_shape:\ttorch.Size([50, 100])\n",
"========================================\n",
"clamp,\tcall_method,\t(linear1,)\n",
"input_shape:\ttorch.Size([50, 100])\t\n",
"output_shape:\ttorch.Size([50, 100])\n",
"========================================\n",
"linear2,\tcall_module,\t(clamp,)\n",
"weight_shape: torch.Size([10, 100])\n",
"input_shape:\ttorch.Size([50, 100])\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"========================================\n",
"clamp,\tcall_method,\t(linear2,)\n",
"input_shape:\ttorch.Size([50, 10])\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"========================================\n",
"<built-in function mul>,\tcall_function,\t(clamp_1, clamp_1)\n",
"input_shape:\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"========================================\n",
"output,\toutput,\t(mul,)\n",
"input_shape:\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"========================================\n"
]
}
],
"source": [
"v_maps = {}\n",
"\n",
"for node in gm.graph.nodes:\n",
" # print(f\"{node.target},\\t{node.op},\\t{node.meta['tensor_meta'].dtype},\\t{node.meta['tensor_meta'].shape}\")\n",
" print(f\"{node.target},\\t{node.op},\\t{node.args}\")\n",
" node_op_type = str(node.target).split(\".\")[-1]\n",
" \n",
" if node.op == \"call_function\":\n",
" pass\n",
" elif node.op == \"call_method\":\n",
" pass\n",
" elif node.op == \"call_module\":\n",
" if node_op_type not in [\"relu\", \"maxpool\", \"avgpool\"]:\n",
" print(f\"weight_shape: {net.state_dict()[node.target + '.weight'].shape}\")\n",
" else:\n",
" print(f\"weight_shape: None\")\n",
" print(\"input_shape:\", end=\"\\t\")\n",
" for arg in node.args:\n",
" if str(arg) not in v_maps:\n",
" continue\n",
" print(f\"{v_maps[str(arg)]}\", end=\"\\t\")\n",
" print()\n",
" print(f\"output_shape:\\t{node.meta['tensor_meta'].shape}\")\n",
" v_maps[str(node.target)] = node.meta['tensor_meta'].shape\n",
" print(\"==\" * 20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 284,
"metadata": {},
"outputs": [],
"source": [
"def count_clamp(input_shapes, output_shapes):\n",
" return 0\n",
"\n",
"def count_mul(input_shapes, output_shapes):\n",
" # element-wise: one multiply per output element\n",
" return output_shapes[0].numel()\n",
"\n",
"def count_nn_linear(input_shapes, output_shapes):\n",
" in_shape = input_shapes[0]\n",
" out_shape = output_shapes[0]\n",
" in_features = in_shape[-1]\n",
" num_elements = out_shape.numel()\n",
" return in_features * num_elements\n",
"\n",
"count_map = {\n",
" nn.Linear: count_nn_linear,\n",
" \"clamp\": count_clamp,\n",
" \"<built-in function mul>\": count_mul,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 285,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NodeOP:placeholder,\tTarget:x,\tNodeName:x,\tNodeArgs:()\n",
"input_shape:\t\n",
"output_shape:\ttorch.Size([50, 1000])\n",
"NodeFlops: 0\n",
"========================================\n",
"NodeOP:call_module,\tTarget:linear1,\tNodeName:linear1,\tNodeArgs:(x,)\n",
"input_shape:\ttorch.Size([50, 1000])\t\n",
"output_shape:\ttorch.Size([50, 100])\n",
"<class 'torch.nn.modules.linear.Linear'> True\n",
"weight_shape: torch.Size([100, 1000])\n",
"NodeFlops: 5000000\n",
"========================================\n",
"NodeOP:call_method,\tTarget:clamp,\tNodeName:clamp,\tNodeArgs:(linear1,)\n",
"input_shape:\ttorch.Size([50, 100])\t\n",
"output_shape:\ttorch.Size([50, 100])\n",
"NodeFlops: 0\n",
"========================================\n",
"NodeOP:call_module,\tTarget:linear2,\tNodeName:linear2,\tNodeArgs:(clamp,)\n",
"input_shape:\ttorch.Size([50, 100])\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"<class 'torch.nn.modules.linear.Linear'> True\n",
"weight_shape: torch.Size([10, 100])\n",
"NodeFlops: 50000\n",
"========================================\n",
"NodeOP:call_method,\tTarget:clamp,\tNodeName:clamp_1,\tNodeArgs:(linear2,)\n",
"input_shape:\ttorch.Size([50, 10])\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"NodeFlops: 0\n",
"========================================\n",
"NodeOP:call_function,\tTarget:<built-in function mul>,\tNodeName:mul,\tNodeArgs:(clamp_1, clamp_1)\n",
"input_shape:\ttorch.Size([50, 10])\ttorch.Size([50, 10])\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"NodeFlops: 500\n",
"========================================\n",
"NodeOP:output,\tTarget:output,\tNodeName:output,\tNodeArgs:(mul,)\n",
"input_shape:\ttorch.Size([50, 10])\t\n",
"output_shape:\ttorch.Size([50, 10])\n",
"NodeFlops: 0\n",
"========================================\n"
]
}
],
"source": [
"v_maps = {}\n",
"total_flops = 0\n",
"\n",
"for node in gm.graph.nodes:\n",
" # print(f\"{node.target},\\t{node.op},\\t{node.meta['tensor_meta'].dtype},\\t{node.meta['tensor_meta'].shape}\")\n",
" print(f\"NodeOP:{node.op},\\tTarget:{node.target},\\tNodeName:{node.name},\\tNodeArgs:{node.args}\")\n",
" node_op_type = str(node.target).split(\".\")[-1]\n",
" node_flops = None\n",
" \n",
" input_shapes = [] \n",
" output_shapes = []\n",
" print(\"input_shape:\", end=\"\\t\")\n",
" for arg in node.args:\n",
" if str(arg) not in v_maps:\n",
" continue\n",
" print(f\"{v_maps[str(arg)]}\", end=\"\\t\")\n",
" input_shapes.append(v_maps[str(arg)])\n",
" print()\n",
" print(f\"output_shape:\\t{node.meta['tensor_meta'].shape}\")\n",
" output_shapes.append(node.meta['tensor_meta'].shape)\n",
" \n",
" if node.op in [\"output\", \"placeholder\"]:\n",
" node_flops = 0\n",
" elif node.op == \"call_function\":\n",
" # free function such as operator.mul; count_map is keyed by str(node.target)\n",
" if str(node.target) in count_map:\n",
" node_flops = count_map[str(node.target)](input_shapes, output_shapes)\n",
" pass\n",
" elif node.op == \"call_method\":\n",
" # tensor method call; node.target is the method name string, e.g. 'clamp'\n",
" # print(str(node.target) in count_map, str(node.target), count_map.keys())\n",
" if str(node.target) in count_map:\n",
" node_flops = count_map[str(node.target)](input_shapes, output_shapes)\n",
" elif node.op == \"call_module\":\n",
" # nn.Module call; count_map is keyed by the submodule's class\n",
" m = getattr(net, node.target, None)\n",
" print(type(m), type(m) in count_map)\n",
" if type(m) in count_map:\n",
" node_flops = count_map[type(m)](input_shapes, output_shapes)\n",
" if node_op_type not in [\"relu\", \"maxpool\", \"avgpool\"]:\n",
" print(f\"weight_shape: {net.state_dict()[node.target + '.weight'].shape}\")\n",
" else:\n",
" print(f\"weight_shape: None\")\n",
" \n",
" v_maps[str(node.name)] = node.meta['tensor_meta'].shape\n",
"\n",
" print(f\"NodeFlops: {node_flops}\")\n",
" if node_flops is not None:\n",
" total_flops += node_flops\n",
" print(\"==\" * 20)"
]
},
{
"cell_type": "code",
"execution_count": 286,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5050500"
]
},
"execution_count": 286,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"total_flops"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment