Created October 2, 2021 12:23
-
-
Save Lyken17/60dec0b4c64db6f85918179eca35f432 to your computer and use it in GitHub Desktop.
Use torch.fx to count FLOPs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 266, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"\n", | |
"from torch.fx import symbolic_trace" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 277, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Simple module for demonstration\n", | |
"\n", | |
"class MyOP(nn.Module):\n", | |
"    def forward(self, input):\n", | |
"        return input - 1\n", | |
"\n", | |
"class MyModule(torch.nn.Module):\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        self.param = torch.nn.Parameter(torch.rand(3, 4))\n", | |
"        self.linear = torch.nn.Linear(4, 5)\n", | |
"        self.linear2 = torch.nn.Linear(5, 5)\n", | |
"        self.myop = MyOP()\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        out = self.linear(x + self.param)\n", | |
"        out = out ** 2\n", | |
"        out = self.myop(out)\n", | |
"        # linear2 takes 5 input features (the output width of self.linear);\n", | |
"        # x must broadcast against the (3, 4) param, so feeding x here would\n", | |
"        # fail at runtime and leave `out` unused.\n", | |
"        return self.linear2(out).clamp(min=0.0, max=1.0)\n", | |
"\n", | |
"\n", | |
"module = MyModule()\n", | |
"# Symbolic tracing frontend - captures the semantics of the module\n", | |
"symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 278, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"graph():\n", | |
" %x : [#users=2] = placeholder[target=x]\n", | |
" %param : [#users=1] = get_attr[target=param]\n", | |
" %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})\n", | |
" %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})\n", | |
" %pow_1 : [#users=1] = call_function[target=operator.pow](args = (%linear, 2), kwargs = {})\n", | |
" %sub : [#users=0] = call_function[target=operator.sub](args = (%pow_1, 1), kwargs = {})\n", | |
" %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})\n", | |
" %clamp : [#users=1] = call_method[target=clamp](args = (%linear2,), kwargs = {min: 0.0, max: 1.0})\n", | |
" return clamp\n" | |
] | |
} | |
], | |
"source": [ | |
"print(symbolic_traced.graph)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 279, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"g = symbolic_traced.graph" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 280, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"opcode name target args kwargs\n", | |
"------------- ------- ----------------------- ----------- ------------------------\n", | |
"placeholder x x () {}\n", | |
"get_attr param param () {}\n", | |
"call_function add <built-in function add> (x, param) {}\n", | |
"call_module linear linear (add,) {}\n", | |
"call_function pow_1 <built-in function pow> (linear, 2) {}\n", | |
"call_function sub <built-in function sub> (pow_1, 1) {}\n", | |
"call_module linear2 linear2 (x,) {}\n", | |
"call_method clamp clamp (linear2,) {'min': 0.0, 'max': 1.0}\n", | |
"output output output (clamp,) {}\n" | |
] | |
} | |
], | |
"source": [ | |
"g.print_tabular()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 281, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"call_module linear (add,)\n", | |
"call_module linear2 (x,)\n", | |
"call_method clamp (linear2,)\n" | |
] | |
} | |
], | |
"source": [ | |
"for node in g.nodes:\n", | |
" if node.op in [\"call_module\", \"call_method\"]:\n", | |
" print(node.op, node.target, node.args)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 199, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"def forward(self, x):\n", | |
" param = self.param\n", | |
" add = x + param; param = None\n", | |
" linear = self.linear(add); add = None\n", | |
" linear2 = self.linear2(x); x = None\n", | |
" clamp = linear2.clamp(min = 0.0, max = 1.0); linear2 = None\n", | |
" return clamp\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"print(symbolic_traced.code)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 200, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl" | |
] | |
}, | |
"execution_count": 200, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"type(symbolic_traced)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 201, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.fx.passes.shape_prop import ShapeProp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 282, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"TwoLayerNet(\n", | |
" (linear1): Linear(in_features=1000, out_features=100, bias=True)\n", | |
" (linear2): Linear(in_features=100, out_features=10, bias=True)\n", | |
")\n" | |
] | |
} | |
], | |
"source": [ | |
"class TwoLayerNet(torch.nn.Module):\n", | |
" def __init__(self, D_in, H, D_out):\n", | |
" super(TwoLayerNet, self).__init__()\n", | |
" self.linear1 = torch.nn.Linear(D_in, H)\n", | |
" self.linear2 = torch.nn.Linear(H, D_out)\n", | |
" def forward(self, x):\n", | |
" h_relu = self.linear1(x).clamp(min=0)\n", | |
" y_pred = self.linear2(h_relu).clamp(min=0)\n", | |
" \n", | |
" return y_pred * y_pred\n", | |
"\n", | |
"N, D_in, H, D_out = 64, 1000, 100, 10\n", | |
"x = torch.randn(N, D_in)\n", | |
"y = torch.randn(N, D_out)\n", | |
"net = TwoLayerNet(D_in, H, D_out)\n", | |
"gm = torch.fx.symbolic_trace(net)\n", | |
"sample_input = torch.randn(50, D_in)\n", | |
"ShapeProp(gm).propagate(sample_input);\n", | |
"print(net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 283, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"x,\tplaceholder,\t()\n", | |
"input_shape:\t\n", | |
"output_shape:\ttorch.Size([50, 1000])\n", | |
"========================================\n", | |
"linear1,\tcall_module,\t(x,)\n", | |
"weight_shape: torch.Size([100, 1000])\n", | |
"input_shape:\ttorch.Size([50, 1000])\t\n", | |
"output_shape:\ttorch.Size([50, 100])\n", | |
"========================================\n", | |
"clamp,\tcall_method,\t(linear1,)\n", | |
"input_shape:\ttorch.Size([50, 100])\t\n", | |
"output_shape:\ttorch.Size([50, 100])\n", | |
"========================================\n", | |
"linear2,\tcall_module,\t(clamp,)\n", | |
"weight_shape: torch.Size([10, 100])\n", | |
"input_shape:\ttorch.Size([50, 100])\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"========================================\n", | |
"clamp,\tcall_method,\t(linear2,)\n", | |
"input_shape:\ttorch.Size([50, 10])\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"========================================\n", | |
"<built-in function mul>,\tcall_function,\t(clamp_1, clamp_1)\n", | |
"input_shape:\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"========================================\n", | |
"output,\toutput,\t(mul,)\n", | |
"input_shape:\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"========================================\n" | |
] | |
} | |
], | |
"source": [ | |
"v_maps = {}\n", | |
"# Qualified name -> submodule; also resolves dotted targets such as \"layer1.0.conv\".\n", | |
"modules = dict(net.named_modules())\n", | |
"\n", | |
"for node in gm.graph.nodes:\n", | |
"    print(f\"{node.target},\\t{node.op},\\t{node.args}\")\n", | |
"\n", | |
"    if node.op == \"call_module\":\n", | |
"        # Modules without a `weight` (ReLU, pooling, ...) print None instead of\n", | |
"        # raising KeyError, whatever attribute name they were registered under.\n", | |
"        weight = getattr(modules.get(node.target), \"weight\", None)\n", | |
"        print(f\"weight_shape: {None if weight is None else weight.shape}\")\n", | |
"    print(\"input_shape:\", end=\"\\t\")\n", | |
"    for arg in node.args:\n", | |
"        if str(arg) not in v_maps:\n", | |
"            continue\n", | |
"        print(f\"{v_maps[str(arg)]}\", end=\"\\t\")\n", | |
"    print()\n", | |
"    print(f\"output_shape:\\t{node.meta['tensor_meta'].shape}\")\n", | |
"    # Key by node.name: a Node argument stringifies to its name, which differs\n", | |
"    # from node.target for call_function nodes (\"mul\" vs operator.mul) and for\n", | |
"    # repeated methods (\"clamp_1\" vs \"clamp\").\n", | |
"    v_maps[node.name] = node.meta['tensor_meta'].shape\n", | |
"    print(\"==\" * 20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 284, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def count_clamp(input_shapes, output_shapes):\n", | |
" return 0\n", | |
"\n", | |
"def count_mul(input_shapes, output_shapes):\n", | |
" # element-wise\n", | |
" return output_shapes[0].numel()\n", | |
"\n", | |
"def count_nn_linear(input_shapes, output_shapes):\n", | |
" in_shape = input_shapes[0]\n", | |
" out_shape = output_shapes[0]\n", | |
" in_features = in_shape[-1]\n", | |
" num_elements = out_shape.numel()\n", | |
" return in_features * num_elements\n", | |
"\n", | |
"count_map = {\n", | |
" nn.Linear: count_nn_linear,\n", | |
" \"clamp\": count_clamp,\n", | |
" \"<built-in function mul>\": count_mul,\n", | |
"}\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 285, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"NodeOP:placeholder,\tTarget:x,\tNodeName:x,\tNodeArgs:()\n", | |
"input_shape:\t\n", | |
"output_shape:\ttorch.Size([50, 1000])\n", | |
"NodeFlops: 0\n", | |
"========================================\n", | |
"NodeOP:call_module,\tTarget:linear1,\tNodeName:linear1,\tNodeArgs:(x,)\n", | |
"input_shape:\ttorch.Size([50, 1000])\t\n", | |
"output_shape:\ttorch.Size([50, 100])\n", | |
"<class 'torch.nn.modules.linear.Linear'> True\n", | |
"weight_shape: torch.Size([100, 1000])\n", | |
"NodeFlops: 5000000\n", | |
"========================================\n", | |
"NodeOP:call_method,\tTarget:clamp,\tNodeName:clamp,\tNodeArgs:(linear1,)\n", | |
"input_shape:\ttorch.Size([50, 100])\t\n", | |
"output_shape:\ttorch.Size([50, 100])\n", | |
"NodeFlops: 0\n", | |
"========================================\n", | |
"NodeOP:call_module,\tTarget:linear2,\tNodeName:linear2,\tNodeArgs:(clamp,)\n", | |
"input_shape:\ttorch.Size([50, 100])\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"<class 'torch.nn.modules.linear.Linear'> True\n", | |
"weight_shape: torch.Size([10, 100])\n", | |
"NodeFlops: 50000\n", | |
"========================================\n", | |
"NodeOP:call_method,\tTarget:clamp,\tNodeName:clamp_1,\tNodeArgs:(linear2,)\n", | |
"input_shape:\ttorch.Size([50, 10])\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"NodeFlops: 0\n", | |
"========================================\n", | |
"NodeOP:call_function,\tTarget:<built-in function mul>,\tNodeName:mul,\tNodeArgs:(clamp_1, clamp_1)\n", | |
"input_shape:\ttorch.Size([50, 10])\ttorch.Size([50, 10])\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"NodeFlops: 500\n", | |
"========================================\n", | |
"NodeOP:output,\tTarget:output,\tNodeName:output,\tNodeArgs:(mul,)\n", | |
"input_shape:\ttorch.Size([50, 10])\t\n", | |
"output_shape:\ttorch.Size([50, 10])\n", | |
"NodeFlops: 0\n", | |
"========================================\n" | |
] | |
} | |
], | |
"source": [ | |
"v_maps = {}\n", | |
"total_flops = 0\n", | |
"# Qualified name -> submodule, so dotted call_module targets (e.g. \"layer1.0.conv\")\n", | |
"# resolve; getattr(net, target) only finds direct children.\n", | |
"modules = dict(net.named_modules())\n", | |
"\n", | |
"for node in gm.graph.nodes:\n", | |
"    print(f\"NodeOP:{node.op},\\tTarget:{node.target},\\tNodeName:{node.name},\\tNodeArgs:{node.args}\")\n", | |
"    node_flops = None\n", | |
"\n", | |
"    input_shapes = []\n", | |
"    output_shapes = []\n", | |
"    print(\"input_shape:\", end=\"\\t\")\n", | |
"    for arg in node.args:\n", | |
"        if str(arg) not in v_maps:\n", | |
"            continue\n", | |
"        print(f\"{v_maps[str(arg)]}\", end=\"\\t\")\n", | |
"        input_shapes.append(v_maps[str(arg)])\n", | |
"    print()\n", | |
"    print(f\"output_shape:\\t{node.meta['tensor_meta'].shape}\")\n", | |
"    output_shapes.append(node.meta['tensor_meta'].shape)\n", | |
"\n", | |
"    if node.op in [\"output\", \"placeholder\"]:\n", | |
"        node_flops = 0\n", | |
"    elif node.op in [\"call_function\", \"call_method\"]:\n", | |
"        # call_function targets are callables (here operator.mul, whose str() is\n", | |
"        # \"<built-in function mul>\"); call_method targets are Tensor method names.\n", | |
"        if str(node.target) in count_map:\n", | |
"            node_flops = count_map[str(node.target)](input_shapes, output_shapes)\n", | |
"    elif node.op == \"call_module\":\n", | |
"        # torch.nn submodules, dispatched on their class\n", | |
"        m = modules.get(node.target)\n", | |
"        print(type(m), type(m) in count_map)\n", | |
"        if type(m) in count_map:\n", | |
"            node_flops = count_map[type(m)](input_shapes, output_shapes)\n", | |
"        # Modules without a `weight` (ReLU, pooling, ...) print None instead of raising KeyError.\n", | |
"        weight = getattr(m, \"weight\", None)\n", | |
"        print(f\"weight_shape: {None if weight is None else weight.shape}\")\n", | |
"\n", | |
"    # Key by node.name: that is what str(arg) yields for Node arguments.\n", | |
"    v_maps[str(node.name)] = node.meta['tensor_meta'].shape\n", | |
"\n", | |
"    print(f\"NodeFlops: {node_flops}\")\n", | |
"    if node_flops is not None:\n", | |
"        total_flops += node_flops\n", | |
"    print(\"==\" * 20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 286, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"5050500" | |
] | |
}, | |
"execution_count": 286, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"total_flops" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.