Simple Model for Demonstrating PyTorch's Tracing Based Selective Build
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNkubUgul95Kp02whwaY67b",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/dhruvbird/65fd800983f362a72d78afe68031568c/simple-model-for-demonstrating-pytorch-s-tracing-based-selective-build.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uHYA9J9I77-Q"
},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "markdown",
"source": [
"Create a simple model that uses only the add, sub, and mul operators. We use it to demonstrate building a PyTorch runtime that supports just the operators (and dtypes) this model needs, i.e., the resulting custom build is meant to run this particular model rather than arbitrary models.\n"
],
"metadata": {
"id": "J_4H4hNN8SHj"
}
},
{
"cell_type": "code",
"source": [
"class SimplePyTorchModel(torch.nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.w = torch.zeros(3, 6)\n",
"\n",
"    def forward(self, t1: torch.Tensor) -> torch.Tensor:\n",
"        t2 = t1 * 2.0\n",
"        t3 = t1 + t2\n",
"        t4 = t3 * 5.0\n",
"        t5 = t4 - t1\n",
"        return t5"
],
"metadata": {
"id": "fGpmFq1Y8Rjs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run the model and generate the output tensor\n",
"m = SimplePyTorchModel()\n",
"with torch.no_grad():\n",
"    t1 = torch.ones(3, 6)\n",
"    tret = m(t1)"
],
"metadata": {
"id": "VSfKXZLx8ReF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Display the input and output tensors\n",
"t1, tret"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mbB2Is398RTb",
"outputId": "3b0fb19f-d360-4c25-b0b5-da09450c5ce9"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(tensor([[1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1.]]), tensor([[14., 14., 14., 14., 14., 14.],\n",
" [14., 14., 14., 14., 14., 14.],\n",
" [14., 14., 14., 14., 14., 14.]]))"
]
},
"metadata": {},
"execution_count": 20
}
]
},
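{
"cell_type": "markdown",
"source": [
"Every element of the output is 14 because `forward` computes `(t1 + 2 * t1) * 5 - t1`, which simplifies to `14 * t1`, and `t1` is all ones. The cell below is a small sanity check added for illustration (not part of the original walkthrough): it repeats the same four operations on a random input `x`, leaving the notebook's `t1` untouched, and compares the result with `14 * x`.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"# Repeat the four operations from SimplePyTorchModel.forward() on a random input\n",
"x = torch.rand(3, 6)\n",
"t2 = x * 2.0\n",
"t3 = x + t2\n",
"t4 = t3 * 5.0\n",
"t5 = t4 - x\n",
"\n",
"# (x + 2x) * 5 - x == 14x, up to floating-point rounding\n",
"assert torch.allclose(t5, 14.0 * x)\n",
"print(\"forward(x) matches 14 * x:\", torch.allclose(t5, 14.0 * x))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},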
{
"cell_type": "code",
"source": [
"# Compile the model with TorchScript so that it can be serialized\n",
"scripted = torch.jit.script(m)"
],
"metadata": {
"id": "k6V5EYqX9RDg"
},
"execution_count": null,
"outputs": []
},
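{
"cell_type": "markdown",
"source": [
"Scripting also exposes which operators `forward` calls directly. As a sketch, the cell below scripts a standalone function `simple_forward` (a hypothetical name) with the same body as `forward`, so that it runs on its own, and collects the `aten::` node kinds from its TorchScript graph; on this notebook's `scripted` module, `scripted.graph` should show the same `aten` operators. `Node.kind()` omits the overload name (for example `aten::mul` rather than `aten::mul.Scalar`), whereas the tracer's YAML output further below lists full overload names.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"\n",
"@torch.jit.script\n",
"def simple_forward(t1: torch.Tensor) -> torch.Tensor:\n",
"    # Same computation as SimplePyTorchModel.forward()\n",
"    t2 = t1 * 2.0\n",
"    t3 = t1 + t2\n",
"    t4 = t3 * 5.0\n",
"    t5 = t4 - t1\n",
"    return t5\n",
"\n",
"\n",
"# Collect the distinct aten operators in the graph, ignoring prim:: nodes such as constants\n",
"aten_ops = sorted(\n",
"    {node.kind() for node in simple_forward.graph.nodes() if node.kind().startswith(\"aten::\")}\n",
")\n",
"print(aten_ops)  # ['aten::add', 'aten::mul', 'aten::sub']"
],
"metadata": {},
"execution_count": null,
"outputs": []
},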
{
"cell_type": "code",
"source": [
"# Bundle representative inputs with the model so that the model tracer can run\n",
"# the model on them and record the operators (and dtypes) that get used\n",
"import torch.utils.bundled_inputs\n",
"\n",
"bundled_model_input = [\n",
" (torch.utils.bundled_inputs.bundle_large_tensor(t1), ),\n",
" (torch.utils.bundled_inputs.bundle_large_tensor(torch.rand(3, 6)), ),\n",
"]\n",
"\n",
"bundled = torch.utils.bundled_inputs.bundle_inputs(scripted, bundled_model_input)"
],
"metadata": {
"id": "4DldIoi19fQA"
},
"execution_count": null,
"outputs": []
},
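{
"cell_type": "markdown",
"source": [
"To see what the tracer will be fed, the inputs can be read back through the helper methods that `bundle_inputs` adds to the module (the tracer presumably relies on these same methods). The sketch below bundles two inputs into a hypothetical `BundledDemoModel`, which has the same `forward` as `SimplePyTorchModel`, so that the cell runs on its own, then retrieves them with `get_all_bundled_inputs()` and checks that each one produces `14 * x`. Calling `bundled.get_all_bundled_inputs()` on the notebook's own `bundled` module should work the same way.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.utils.bundled_inputs\n",
"\n",
"\n",
"class BundledDemoModel(torch.nn.Module):\n",
"    # Hypothetical stand-in with the same forward() as SimplePyTorchModel\n",
"    def forward(self, t1: torch.Tensor) -> torch.Tensor:\n",
"        t2 = t1 * 2.0\n",
"        t3 = t1 + t2\n",
"        t4 = t3 * 5.0\n",
"        return t4 - t1\n",
"\n",
"\n",
"demo_bundled = torch.utils.bundled_inputs.bundle_inputs(\n",
"    torch.jit.script(BundledDemoModel()),\n",
"    [\n",
"        (torch.utils.bundled_inputs.bundle_large_tensor(torch.ones(3, 6)),),\n",
"        (torch.utils.bundled_inputs.bundle_large_tensor(torch.rand(3, 6)),),\n",
"    ],\n",
")\n",
"\n",
"# bundle_inputs() adds helper methods to the module, including get_all_bundled_inputs(),\n",
"# which returns one argument tuple per bundled input\n",
"all_inputs = demo_bundled.get_all_bundled_inputs()\n",
"for (x,) in all_inputs:\n",
"    assert torch.allclose(demo_bundled(x), 14.0 * x)\n",
"print(\"number of bundled inputs:\", len(all_inputs))  # 2"
],
"metadata": {},
"execution_count": null,
"outputs": []
},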
{
"cell_type": "code",
"source": [
"# Save the model in the PyTorch Lite Interpreter (.ptl) format\n",
"bundled._save_for_lite_interpreter(\"/tmp/path_to_model.ptl\")"
],
"metadata": {
"id": "OvDxQ9mr_K9w"
},
"execution_count": null,
"outputs": []
},
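{
"cell_type": "markdown",
"source": [
"Before compiling the C++ program, it can be useful to confirm from Python that a `.ptl` file loads and runs in the lite interpreter. This is a minimal sketch assuming `torch.jit.mobile._load_for_lite_interpreter` (an underscore-prefixed, semi-private API) is available in your PyTorch version; `LiteDemoModel` and the file name `lite_demo_model.ptl` are hypothetical stand-ins, and the file is written to a fresh temporary directory so `/tmp/path_to_model.ptl` is not overwritten.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import torch\n",
"from torch.jit.mobile import _load_for_lite_interpreter\n",
"\n",
"\n",
"class LiteDemoModel(torch.nn.Module):\n",
"    # Hypothetical stand-in with the same forward() as SimplePyTorchModel\n",
"    def forward(self, t1: torch.Tensor) -> torch.Tensor:\n",
"        t2 = t1 * 2.0\n",
"        t3 = t1 + t2\n",
"        t4 = t3 * 5.0\n",
"        return t4 - t1\n",
"\n",
"\n",
"# Save to a fresh temporary directory so /tmp/path_to_model.ptl is left untouched\n",
"lite_path = os.path.join(tempfile.mkdtemp(), \"lite_demo_model.ptl\")\n",
"torch.jit.script(LiteDemoModel())._save_for_lite_interpreter(lite_path)\n",
"\n",
"# Load the file with the lite interpreter (roughly the Python counterpart of _load_for_mobile in C++)\n",
"lite_model = _load_for_lite_interpreter(lite_path)\n",
"x = torch.rand(3, 6)\n",
"lite_out = lite_model(x)\n",
"assert torch.allclose(lite_out, 14.0 * x)\n",
"print(\"lite interpreter output matches 14 * x:\", torch.allclose(lite_out, 14.0 * x))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},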
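{
"cell_type": "code",
"source": [
"# Write a minimal C++ program to /tmp/main.cpp that loads the .ptl model and runs it\n",
"# once. It is only used to link against the PyTorch runtime and measure binary size.\n",
"\n",
"tmp_main_cpp = \"\"\"\n",
"#include <iostream>\n",
"#include <ATen/ATen.h>\n",
"#include <torch/csrc/jit/mobile/import.h>\n",
"\n",
"\n",
"int main() {\n",
"    auto m = torch::jit::_load_for_mobile(\"/tmp/path_to_model.ptl\");\n",
"    // forward() expects one 3x6 float tensor; calling it with no inputs would fail\n",
"    auto res = m.forward({at::ones({3, 6})});\n",
"    return 0;\n",
"}\n",
"\"\"\"\n",
"\n",
"with open(\"/tmp/main.cpp\", \"wt\") as f:\n",
"    f.write(tmp_main_cpp)"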
],
"metadata": {
"id": "IBVn9pzN_Y1j"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Set up the PyTorch source directory\n",
"\n",
"Follow the instructions on [this page](https://github.com/pytorch/pytorch#from-source) to set up a local PyTorch source checkout.\n",
"\n",
"Then, change into the source directory and activate your Conda environment:\n",
"\n",
"```\n",
"cd pytorch\n",
"conda activate\n",
"```\n"
],
"metadata": {
"id": "w_1YS2SRAZ88"
}
},
{
"cell_type": "markdown",
"source": [
"# Build PyTorch Non-Selectively\n",
"\n",
"Commands to build the PyTorch mobile runtime (CPU backend) with all operators included, link a sample binary against it, and check the size of that binary.\n",
"\n",
"#### Build PyTorch without selective build\n",
"```\n",
"BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 \\\n",
" USE_LIGHTWEIGHT_DISPATCH=0 \\\n",
" BUILD_LITE_INTERPRETER=1 \\\n",
" ./scripts/build_mobile.sh\n",
"```\n",
"\n",
"#### Build sample binary using this command\n",
"```\n",
"g++ /tmp/main.cpp -L build_mobile/lib/ \\\n",
" -I build_mobile/install/include/ \\\n",
" -ffunction-sections -fdata-sections -Wl,--gc-sections \\\n",
" -lpthread -lc10 \\\n",
" -Wl,--whole-archive -ltorch_cpu -Wl,--no-whole-archive \\\n",
" -ltorch -lXNNPACK -lpytorch_qnnpack -lcpuinfo -lclog \\\n",
" -lpthreadpool -lkineto -lfmt -ldl -lc10\n",
"```\n",
"\n",
"#### Check the size of this binary (unstripped and stripped)\n",
"```\n",
"[ ~/pytorch] ls -lah a.out\n",
"-rwxr-xr-x 1 username users 49M Sep 9 11:22 a.out\n",
"[ ~/pytorch] strip a.out\n",
"[ ~/pytorch] ls -lah a.out\n",
"-rwxr-xr-x 1 username users 34M Sep 9 11:23 a.out\n",
"```\n",
"\n",
"\n"
],
"metadata": {
"id": "hHFFOFkvAH9R"
}
},
{
"cell_type": "markdown",
"source": [
"# Build the PyTorch Model Tracer\n",
"\n",
"#### Clean out cached configuration\n",
"```\n",
"make clean\n",
"```\n",
"\n",
"#### Build Model Tracer Binary\n",
"```\n",
"USE_NUMPY=0 \\\n",
" USE_DISTRIBUTED=0 \\\n",
" USE_CUDA=0 \\\n",
" TRACING_BASED=1 \\\n",
" python setup.py develop --cmake\n",
"```\n",
"\n",
"#### Run the model tracer on a .ptl (PyTorch Lite Interpreter) model to produce the YAML file containing the set of used operators.\n",
"```\n",
"build/bin/model_tracer \\\n",
" --model_input_path /tmp/path_to_model.ptl \\\n",
" --build_yaml_path /tmp/selected_ops.yaml\n",
"```\n",
"\n"
],
"metadata": {
"id": "dINUnJP0Y7oj"
}
},
{
"cell_type": "markdown",
"source": [
"# Build PyTorch Selectively\n",
"\n",
"#### Clean out cached configuration\n",
"```\n",
"make clean\n",
"```\n",
"\n",
"#### Build PyTorch with only the selected operators (from the YAML file) using the host toolchain\n",
"```\n",
"BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 \\\n",
" USE_LIGHTWEIGHT_DISPATCH=0 \\\n",
" BUILD_LITE_INTERPRETER=1 \\\n",
" SELECTED_OP_LIST=/tmp/selected_ops.yaml \\\n",
" TRACING_BASED=1 \\\n",
" ./scripts/build_mobile.sh\n",
"```\n",
"\n",
"#### Build sample binary using this command\n",
"```\n",
"g++ /tmp/main.cpp -L build_mobile/lib/ \\\n",
" -I build_mobile/install/include/ \\\n",
" -ffunction-sections -fdata-sections -Wl,--gc-sections \\\n",
" -lpthread -lc10 \\\n",
" -Wl,--whole-archive -ltorch_cpu -Wl,--no-whole-archive \\\n",
" -ltorch -lXNNPACK -lpytorch_qnnpack -lcpuinfo -lclog \\\n",
" -lpthreadpool -lkineto -lfmt -ldl -lc10\n",
"```\n",
"\n",
"#### Check the size of this binary (unstripped and stripped)\n",
"```\n",
"[ ~/pytorch] ls -lah a.out\n",
"-rwxr-xr-x 1 username users 3.7M Sep 9 16:33 a.out\n",
"[ ~/pytorch] strip a.out\n",
"[ ~/pytorch] ls -lah a.out\n",
"-rwxr-xr-x 1 username users 2.7M Sep 9 16:44 a.out\n",
"```\n"
],
"metadata": {
"id": "JZE1ZITLdaR1"
}
},
{
"cell_type": "code",
"source": [
"print(\"Unstripped binary size using Selective Build is {:.1f}% of Original\".format(3.7/49*100.0))\n",
"print(\"Stripped binary size using Selective Build is {:.1f}% of Original\".format(2.7/34*100.0))\n",
"\n",
"print(\"Size saving for unstripped binary is {:.1f}%\".format(100.0 - 3.7/49*100.0))\n",
"print(\"Size saving for stripped binary is {:.1f}%\".format(100.0 - 2.7/34*100.0))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PwtmYk7lA2Ob",
"outputId": "8634da54-1e5f-4cba-858c-bc624bf0c8e3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Unstripped binary size using Selective Build is 7.6% of Original\n",
"Stripped binary size using Selective Build is 7.9% of Original\n",
"Size saving for unstripped binary is 92.4%\n",
"Size saving for stripped binary is 92.1%\n"
]
}
]
},
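{
"cell_type": "markdown",
"source": [
"The percentages above are based on the sizes printed by `ls -lah`, which are rounded (for example, `3.7M`), so they are approximate. Since both builds write the sample binary to `a.out`, keeping both requires copying it between builds (to hypothetical names such as `a.out.full` and `a.out.selective`); with both files available, a helper like the sketched `compare_binaries` below computes the ratios from exact byte counts. The last line reuses the rounded stripped sizes to show that the helper reproduces the figures above.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"\n",
"def size_ratio(original_size: float, selective_size: float) -> tuple:\n",
"    \"\"\"Return (selective size as % of original, % saved); any consistent unit works.\"\"\"\n",
"    percent_of_original = selective_size / original_size * 100.0\n",
"    return percent_of_original, 100.0 - percent_of_original\n",
"\n",
"\n",
"def compare_binaries(original_path: str, selective_path: str) -> tuple:\n",
"    # os.path.getsize() returns exact byte counts, unlike the rounded sizes from ls -lah\n",
"    return size_ratio(os.path.getsize(original_path), os.path.getsize(selective_path))\n",
"\n",
"\n",
"# Usage with real files (hypothetical names): compare_binaries(\"a.out.full\", \"a.out.selective\")\n",
"# The rounded stripped sizes from above (34M and 2.7M) reproduce the earlier figures\n",
"stripped_pct, stripped_saving = size_ratio(34, 2.7)\n",
"print(\"{:.1f}% of original, {:.1f}% saved\".format(stripped_pct, stripped_saving))  # 7.9% of original, 92.1% saved"
],
"metadata": {},
"execution_count": null,
"outputs": []
},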
{
"cell_type": "markdown",
"source": [
"### Contents of /tmp/selected_ops.yaml (operators, dtypes, custom classes)\n",
"[gist](https://gist.github.com/dhruvbird/50e1860b39ae065e57d58f17e0912136)\n",
"\n",
"```\n",
"include_all_non_op_selectives: false\n",
"build_features: []\n",
"operators:\n",
"  aten::__getitem__.t:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::_set_item.str:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::add.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::len.t:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::mul.Scalar:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::sub.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::_local_scalar_dense:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::_to_copy:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::as_strided:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::copy_:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::div.Scalar:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::div.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::empty.memory_format:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::empty_strided:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::eq.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::fill_.Scalar:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::item:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::mul.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::narrow:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::new_empty_strided:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::ones:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::resize_:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::set_.source_Storage:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::set_.source_Storage_storage_offset:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::slice.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::to.dtype:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::to.dtype_layout:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::zero_:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"  aten::zeros:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"kernel_metadata:\n",
"  _local_scalar_dense_cpu:\n",
"    - Float\n",
"  add_stub:\n",
"    - Float\n",
"  copy_:\n",
"    - Bool\n",
"    - Byte\n",
"    - Char\n",
"    - Double\n",
"    - Float\n",
"    - Int\n",
"    - Long\n",
"    - Short\n",
"  copy_kernel:\n",
"    - Bool\n",
"    - Float\n",
"    - Int\n",
"  div_cpu:\n",
"    - Float\n",
"  eq_cpu:\n",
"    - Float\n",
"  fill_cpu:\n",
"    - Float\n",
"  fill_out:\n",
"    - Double\n",
"    - Long\n",
"  mul_cpu:\n",
"    - Float\n",
"custom_classes: []\n",
"```"
],
"metadata": {
"id": "JjizEwhgCTnB"
}
}
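,
{
"cell_type": "markdown",
"source": [
"The YAML file can also be inspected programmatically. As we understand the format, `is_root_operator: true` marks operators invoked directly by the model's TorchScript code, while non-root operators are called internally by other operators; the root list above also contains list and dict operators such as `aten::len.t`, which probably come from the helper methods added by `bundle_inputs` rather than from `forward`. The sketch below assumes PyYAML is installed (building PyTorch from source requires it) and parses a trimmed excerpt of the file; to read the full file, pass `open(\"/tmp/selected_ops.yaml\")` to `yaml.safe_load` instead.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import yaml  # PyYAML, assumed to be installed\n",
"\n",
"# A trimmed excerpt of /tmp/selected_ops.yaml from the previous cell\n",
"SELECTED_OPS_EXCERPT = \"\"\"\n",
"operators:\n",
"  aten::add.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::mul.Scalar:\n",
"    is_used_for_training: false\n",
"    is_root_operator: true\n",
"    include_all_overloads: false\n",
"  aten::mul.Tensor:\n",
"    is_used_for_training: false\n",
"    is_root_operator: false\n",
"    include_all_overloads: false\n",
"kernel_metadata:\n",
"  mul_cpu:\n",
"    - Float\n",
"\"\"\"\n",
"\n",
"config = yaml.safe_load(SELECTED_OPS_EXCERPT)\n",
"operators = config[\"operators\"]\n",
"# Split operators by the is_root_operator flag\n",
"root_ops = sorted(name for name, info in operators.items() if info[\"is_root_operator\"])\n",
"non_root_ops = sorted(name for name, info in operators.items() if not info[\"is_root_operator\"])\n",
"print(\"root:\", root_ops)  # root: ['aten::add.Tensor', 'aten::mul.Scalar']\n",
"print(\"non-root:\", non_root_ops)  # non-root: ['aten::mul.Tensor']\n",
"print(\"mul_cpu dtypes:\", config[\"kernel_metadata\"][\"mul_cpu\"])  # mul_cpu dtypes: ['Float']"
],
"metadata": {},
"execution_count": null,
"outputs": []
}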
]
}