Last active
August 16, 2022 19:40
-
-
Save hjm-aws/6faf373a6cafac00439815983af9f0a1 to your computer and use it in GitHub Desktop.
Getting Started with PyTorch on Cloud TPUs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/hjm-aws/6faf373a6cafac00439815983af9f0a1/embedding-output-is-fp32-even-after-the-embedding-weight-is-casted-to-bf16.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "3P6b3uqfzpDI" | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"# Fail fast with a readable message when no TPU runtime is attached. .get() returns None instead of\n", | |
"# raising KeyError when the variable is unset, so the assertion message below is actually shown.\n", | |
"assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "CHzziBW5AoZH" | |
}, | |
"source": [ | |
"## Installing PyTorch/XLA" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"id": "OApBOAe1fpH_", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "1be3954b-a79d-46d6-9b7e-478046e871aa" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting torch-xla==1.12\n", | |
" Downloading https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl (187.4 MB)\n", | |
"\u001b[K |████████████████████████████████| 187.4 MB 29 kB/s \n", | |
"\u001b[?25hCollecting cloud-tpu-client==0.10\n", | |
" Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)\n", | |
"Collecting torch==1.12.0\n", | |
" Downloading torch-1.12.0-cp37-cp37m-manylinux1_x86_64.whl (776.3 MB)\n", | |
"\u001b[K |████████████████████████████████| 776.3 MB 16 kB/s \n", | |
"\u001b[?25hRequirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from torch-xla==1.12) (1.2.0)\n", | |
"Requirement already satisfied: oauth2client in /usr/local/lib/python3.7/dist-packages (from cloud-tpu-client==0.10) (4.1.3)\n", | |
"Collecting google-api-python-client==1.8.0\n", | |
" Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)\n", | |
"\u001b[K |████████████████████████████████| 57 kB 6.3 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.12.0) (4.1.1)\n", | |
"Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.17.4)\n", | |
"Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.35.0)\n", | |
"Requirement already satisfied: six<2dev,>=1.6.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.15.0)\n", | |
"Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.1)\n", | |
"Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.0.4)\n", | |
"Requirement already satisfied: google-api-core<2dev,>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.31.6)\n", | |
"Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (57.4.0)\n", | |
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.56.4)\n", | |
"Requirement already satisfied: pytz in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2022.1)\n", | |
"Requirement already satisfied: packaging>=14.3 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (21.3)\n", | |
"Requirement already satisfied: protobuf<4.0.0dev,>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.17.3)\n", | |
"Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.23.0)\n", | |
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.2.8)\n", | |
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.2.4)\n", | |
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.9)\n", | |
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=14.3->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.9)\n", | |
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.4.8)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.10)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2022.6.15)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.24.3)\n", | |
"Installing collected packages: google-api-python-client, cloud-tpu-client, torch-xla, torch\n", | |
" Attempting uninstall: google-api-python-client\n", | |
" Found existing installation: google-api-python-client 1.12.11\n", | |
" Uninstalling google-api-python-client-1.12.11:\n", | |
" Successfully uninstalled google-api-python-client-1.12.11\n", | |
" Attempting uninstall: torch\n", | |
" Found existing installation: torch 1.12.1+cu113\n", | |
" Uninstalling torch-1.12.1+cu113:\n", | |
" Successfully uninstalled torch-1.12.1+cu113\n", | |
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | |
"torchvision 0.13.1+cu113 requires torch==1.12.1, but you have torch 1.12.0 which is incompatible.\n", | |
"torchtext 0.13.1 requires torch==1.12.1, but you have torch 1.12.0 which is incompatible.\n", | |
"torchaudio 0.12.1+cu113 requires torch==1.12.1, but you have torch 1.12.0 which is incompatible.\n", | |
"earthengine-api 0.1.318 requires google-api-python-client<2,>=1.12.1, but you have google-api-python-client 1.8.0 which is incompatible.\u001b[0m\n", | |
"Successfully installed cloud-tpu-client-0.10 google-api-python-client-1.8.0 torch-1.12.0 torch-xla-1.12\n" | |
] | |
} | |
], | |
"source": [ | |
"# Versions pinned for this repro. %pip (unlike !pip) installs into the environment of the running kernel.\n", | |
"# pip reports conflicts with the preinstalled torchvision/torchtext/torchaudio; none of them are imported in this notebook.\n", | |
"%pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ls3j-EWI2D2v" | |
}, | |
"source": [ | |
"## Repro\n", | |
"\n", | |
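"On XLA, the output dtype of a `torch.nn.Embedding` module is fp32 even after the module's weight has been cast to bf16 with `.to(torch.bfloat16)`, whereas an embedding created directly with `dtype=torch.bfloat16` produces a bf16 output." | |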
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"id": "42avAvSg17by", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "f161f02c-ea0c-45c3-cd4e-9960f5dbff9e" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Embedding created as bf16, output dtype: torch.bfloat16\n", | |
"Embedding created as fp32 and casted to bf16, output dtype: torch.float32\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch_xla.core.xla_model as xm\n", | |
"\n", | |
"# One lookup index (value 1, within the 1024-row table) placed on the XLA device.\n", | |
"device = xm.xla_device()\n", | |
"index = torch.ones(1, dtype=torch.long, device=device)\n", | |
"\n", | |
"# Case 1: the embedding weight is allocated directly as bf16.\n", | |
"emb = torch.nn.Embedding(1024, 128, dtype=torch.bfloat16, device=device)\n", | |
"emb_out = emb(index)\n", | |
"print(f'Embedding created as bf16, output dtype: {emb_out.dtype}', flush=True)\n", | |
"\n", | |
"# Case 2: the embedding is built with the default dtype and then cast with .to(torch.bfloat16).\n", | |
"# Module.to() returns the module itself, so the cast is chained onto the constructor.\n", | |
"# The recorded output shows float32 here instead of bfloat16 -- the behavior being reported.\n", | |
"emb = torch.nn.Embedding(1024, 128, device=device).to(torch.bfloat16)\n", | |
"emb_out = emb(index)\n", | |
"print(f'Embedding created as fp32 and casted to bf16, output dtype: {emb_out.dtype}', flush=True)\n" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "TPU", | |
"colab": { | |
"collapsed_sections": [], | |
"machine_shape": "hm", | |
"name": "Getting Started with PyTorch on Cloud TPUs", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment