Created February 19, 2020 00:49
-
-
Save AdityaSoni19031997/f877ebb73dd1b10c1758505eac08abae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Installs PyTorch, PyTorch/XLA, and Torchvision | |
# Copy this cell into your own notebooks to use PyTorch on Cloud TPUs | |
# Warning: this may take a couple minutes to run | |
import collections | |
from datetime import datetime, timedelta | |
import os | |
import requests | |
import threading | |
# Each selectable VERSION maps to a pair of version strings:
#   wheels -- the version tag embedded in the client-side wheel file names;
#   server -- the XRT version requested from the TPU host.
_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')

VERSION = "torch_xla==nightly" #@param ["xrt==1.15.0", "torch_xla==nightly"]

# The nightly server build is pinned to yesterday's date (local clock), as YYYYMMDD.
_nightly_stamp = (datetime.today() - timedelta(days=1)).strftime('%Y%m%d')
_VERSION_CHOICES = {
    'xrt==1.15.0': _VersionConfig(wheels='1.15', server='1.15.0'),
    'torch_xla==nightly': _VersionConfig(
        wheels='nightly',
        server='XRT-dev{}'.format(_nightly_stamp),
    ),
}
# An unrecognised VERSION raises KeyError here.
CONFIG = _VERSION_CHOICES[VERSION]

# GCS location the prebuilt wheels are copied from.
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
# NOTE(review): wheel names are pinned to CPython 3.6 (cp36/cp36m); confirm the
# runtime's Python version matches before reusing this cell.
_wheel_tag = CONFIG.wheels
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(_wheel_tag)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(_wheel_tag)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(_wheel_tag)
# Update TPU XRT version
def update_server_xrt():
  """Ask the Colab TPU host to switch its server-side XRT to CONFIG.server.

  The TPU host is taken from the COLAB_TPU_ADDR environment variable
  ("host:port"; only the host part is used), and a POST is sent to the
  version-request endpoint on port 8475. Progress and the outcome are printed.

  Raises:
    KeyError: if COLAB_TPU_ADDR is not set in the environment.
    requests exceptions (e.g. on connection failure) propagate unchanged.
  """
  print('Updating server-side XRT to {} ...'.format(CONFIG.server))
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
      XRT_VERSION=CONFIG.server,
  )
  print(url)
  response = requests.post(url)
  if response.ok:
    print('Done updating server-side XRT: {}'.format(response))
  else:
    # A 4xx/5xx reply (e.g. the requested XRT build is unavailable) used to be
    # reported as "Done", hiding that the server was never updated.
    print('Failed updating server-side XRT: {} {}'.format(
        response.status_code, response.reason))
update = threading.Thread(target=update_server_xrt) | |
update.start() | |
# Install Colab TPU compat PyTorch/TPU wheels and dependencies | |
!pip uninstall -y torch torchvision | |
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" . | |
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" . | |
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" . | |
!pip install "$TORCH_WHEEL" | |
!pip install "$TORCH_XLA_WHEEL" | |
!pip install "$TORCHVISION_WHEEL" | |
!pip install transformers | |
!sudo apt-get install libomp5 | |
update.join() |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.