Getting Started with PyTorch on Cloud TPUs
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/hjm-aws/6faf373a6cafac00439815983af9f0a1/embedding-output-is-fp32-even-after-the-embedding-weight-is-casted-to-bf16.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "3P6b3uqfzpDI"
},
"outputs": [],
"source": [
"import os\n",
"assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CHzziBW5AoZH"
},
"source": [
"## Installing PyTorch/XLA"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "OApBOAe1fpH_",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1be3954b-a79d-46d6-9b7e-478046e871aa"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting torch-xla==1.12\n",
" Downloading https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl (187.4 MB)\n",
"\u001b[K |████████████████████████████████| 187.4 MB 29 kB/s \n",
"\u001b[?25hCollecting cloud-tpu-client==0.10\n",
" Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)\n",
"Collecting torch==1.12.0\n",
" Downloading torch-1.12.0-cp37-cp37m-manylinux1_x86_64.whl (776.3 MB)\n",
"\u001b[K |████████████████████████████████| 776.3 MB 16 kB/s \n",
"\u001b[?25hRequirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from torch-xla==1.12) (1.2.0)\n",
"Requirement already satisfied: oauth2client in /usr/local/lib/python3.7/dist-packages (from cloud-tpu-client==0.10) (4.1.3)\n",
"Collecting google-api-python-client==1.8.0\n",
" Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)\n",
"\u001b[K |████████████████████████████████| 57 kB 6.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.12.0) (4.1.1)\n",
"Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.17.4)\n",
"Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.35.0)\n",
"Requirement already satisfied: six<2dev,>=1.6.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.15.0)\n",
"Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.1)\n",
"Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.0.4)\n",
"Requirement already satisfied: google-api-core<2dev,>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.31.6)\n",
"Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (57.4.0)\n",
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.56.4)\n",
"Requirement already satisfied: pytz in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2022.1)\n",
"Requirement already satisfied: packaging>=14.3 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (21.3)\n",
"Requirement already satisfied: protobuf<4.0.0dev,>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.17.3)\n",
"Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.23.0)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.2.8)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.2.4)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.9)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=14.3->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.9)\n",
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.4.8)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2022.6.15)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.24.3)\n",
"Installing collected packages: google-api-python-client, cloud-tpu-client, torch-xla, torch\n",
" Attempting uninstall: google-api-python-client\n",
" Found existing installation: google-api-python-client 1.12.11\n",
" Uninstalling google-api-python-client-1.12.11:\n",
" Successfully uninstalled google-api-python-client-1.12.11\n",
" Attempting uninstall: torch\n",
" Found existing installation: torch 1.12.1+cu113\n",
" Uninstalling torch-1.12.1+cu113:\n",
" Successfully uninstalled torch-1.12.1+cu113\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"torchvision 0.13.1+cu113 requires torch==1.12.1, but you have torch 1.12.0 which is incompatible.\n",
"torchtext 0.13.1 requires torch==1.12.1, but you have torch 1.12.0 which is incompatible.\n",
"torchaudio 0.12.1+cu113 requires torch==1.12.1, but you have torch 1.12.0 which is incompatible.\n",
"earthengine-api 0.1.318 requires google-api-python-client<2,>=1.12.1, but you have google-api-python-client 1.8.0 which is incompatible.\u001b[0m\n",
"Successfully installed cloud-tpu-client-0.10 google-api-python-client-1.8.0 torch-1.12.0 torch-xla-1.12\n"
]
}
],
"source": [
"!pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ls3j-EWI2D2v"
},
"source": [
"## Repro\n",
"\n",
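"On an XLA device, a `torch.nn.Embedding` module created in fp32 and then cast with `.to(torch.bfloat16)` still returns fp32 output, while one created directly with `dtype=torch.bfloat16` returns bf16 output."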
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "42avAvSg17by",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f161f02c-ea0c-45c3-cd4e-9960f5dbff9e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Embedding created as bf16, output dtype: torch.bfloat16\n",
"Embedding created as fp32 and casted to bf16, output dtype: torch.float32\n"
]
}
],
"source": [
"import torch\n",
"import torch_xla.core.xla_model as xm\n",
"\n",
"device = xm.xla_device()\n",
"# A single lookup index, placed on the TPU.\n",
"index = torch.ones(1, dtype=torch.long, device=device)\n",
"\n",
"# Case 1: weight allocated directly in bf16.\n",
"emb = torch.nn.Embedding(1024, 128, dtype=torch.bfloat16, device=device)\n",
"emb_out = emb(index)\n",
"print(f'Embedding created as bf16, output dtype: {emb_out.dtype}', flush=True)\n",
"\n",
"# Case 2: weight allocated in fp32 (the default), then the module is cast to bf16.\n",
"emb = torch.nn.Embedding(1024, 128, device=device)\n",
"emb = emb.to(torch.bfloat16)\n",
"emb_out = emb(index)\n",
"print(f'Embedding created as fp32 and casted to bf16, output dtype: {emb_out.dtype}', flush=True)\n"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "Getting Started with PyTorch on Cloud TPUs",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
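The comparison in the repro can also be written as a reusable check. The sketch below is not part of the original notebook. It runs one lookup, prints the weight and output dtypes, and asserts that they match for both construction paths. Because an embedding is a table lookup, its output would normally carry the weight's dtype.

The helper names `embedding_output_dtype` and `check_cast_embedding` are hypothetical. On CPU both cases are expected to pass. Passing `xm.xla_device()` instead of `"cpu"` should exercise the TPU path. Judging by the notebook output above, the cast case would then trip the assertion on torch-xla 1.12.

```python
import torch


def embedding_output_dtype(emb: torch.nn.Embedding) -> torch.dtype:
    """Run a single lookup on the embedding's own device and return the output dtype."""
    index = torch.ones(1, dtype=torch.long, device=emb.weight.device)
    return emb(index).dtype


def check_cast_embedding(device="cpu") -> None:
    # Case 1: weight allocated directly in bf16.
    direct = torch.nn.Embedding(1024, 128, dtype=torch.bfloat16, device=device)
    # Case 2: weight allocated in fp32, then the whole module cast to bf16.
    cast = torch.nn.Embedding(1024, 128, device=device).to(torch.bfloat16)

    for name, emb in (("created as bf16", direct), ("cast to bf16", cast)):
        weight_dtype = emb.weight.dtype
        out_dtype = embedding_output_dtype(emb)
        print(f"{name}: weight={weight_dtype}, output={out_dtype}")
        # A lookup should return values in the same dtype as the table it reads from.
        assert out_dtype == weight_dtype, (
            f"{name}: output dtype {out_dtype} != weight dtype {weight_dtype}"
        )


if __name__ == "__main__":
    check_cast_embedding("cpu")
```

The check inspects `emb.weight.dtype` as well as the output. A failure therefore shows whether the cast itself failed or the lookup dropped the bf16 dtype.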