@napoler
Last active September 27, 2021 14:15
electra-pytorch基因序列示例.ipynb
{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"colab": {
"name": "electra-pytorch基因序列示例.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyMBUCZYkTgrM2v7iZaZHuYo",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.9 64-bit"
},
"language_info": {
"name": "python",
"version": "3.6.9",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"<a href=\"https://colab.research.google.com/gist/napoler/1e7b3e161344020e45c4671d4c20854d/electra-pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
],
"metadata": {
"id": "view-in-github",
"colab_type": "text"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [],
"outputs": [],
"metadata": {
"id": "p-kStpAYctni"
}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "hHRDEv57cuQO"
}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
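"# Install into the user site-packages: the original run below failed with\n",
"# \"[Errno 13] Permission denied\" when writing to /usr/local/lib/python3.6/dist-packages.\n",
"# Restart the kernel afterwards so the newly installed packages are importable.\n",
"!pip install --user electra-pytorch\n",
"!pip install --user transformers\n",
"!pip install --user reformer_pytorch\n"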
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting electra-pytorch\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/38/a4/b1484566695354028cfc4d01df520d8590e6fb0791594977f1bfdf275ebb/electra_pytorch-0.1.1-py3-none-any.whl\n",
"Collecting transformers==3.0.2\n",
"\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)\n",
"\u001b[K |████████████████████████████████| 778kB 1.7MB/s \n",
"\u001b[?25hProcessing /home/terry/.cache/pip/wheels/6f/cd/11/f6acd1062135d70bc0a7066808561580d256b3149055cb33ad/sklearn-0.0-py2.py3-none-any.whl\n",
"Collecting scipy\n",
"\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c8/89/63171228d5ced148f5ced50305c89e8576ffc695a90b58fe5bb602b910c2/scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl (25.9MB)\n",
"\u001b[K |████████████████████████████████| 25.9MB 12.1MB/s \n",
"\u001b[?25hCollecting torch>=1.6.0\n",
"\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9a/f1/735e39ed0e3877ff02ffe625989bb421747c3dfd256e37ed92ad32c986be/torch-1.9.1-cp36-cp36m-manylinux1_x86_64.whl (831.4MB)\n",
"\u001b[K |████████████████████████████████| 831.4MB 10kB/s \n",
"\u001b[?25hCollecting filelock\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e0/a5/23349971aaf2bb56cf0bb084e51b4020098e53465c97eeb730e2e2a1da13/filelock-3.1.0-py2.py3-none-any.whl\n",
"Collecting tqdm>=4.27\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/63/f3/b7a1b8e40fd1bd049a34566eb353527bb9b8e9b98f8b6cf803bb64d8ce95/tqdm-4.62.3-py2.py3-none-any.whl\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2->electra-pytorch) (1.16.3)\n",
"Collecting sacremoses\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/36/bf/15f8df78bce5eee8223553123173f010d426565980e457c559a71ecbecc3/sacremoses-0.0.46-py3-none-any.whl\n",
"Collecting tokenizers==0.8.1.rc1\n",
"\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)\n",
"\u001b[K |████████████████████████████████| 3.0MB 13.2MB/s \n",
"\u001b[?25hRequirement already satisfied: packaging in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from transformers==3.0.2->electra-pytorch) (19.0)\n",
"Collecting sentencepiece!=0.1.92\n",
"\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/5b/49/2155d4078e9918003e77b6032a83d71995656bd05707d96e06a44cd6edf6/sentencepiece-0.1.96-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2MB)\n",
"\u001b[K |████████████████████████████████| 1.2MB 12.6MB/s \n",
"\u001b[?25hCollecting dataclasses; python_version < \"3.7\"\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/fe/ca/75fac5856ab5cfa51bbbcefa250182e50441074fdc3f803f6e76451fab43/dataclasses-0.8-py3-none-any.whl\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2->electra-pytorch) (2019.6.8)\n",
"Requirement already satisfied: requests in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from transformers==3.0.2->electra-pytorch) (2.21.0)\n",
"Collecting scikit-learn\n",
"\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d3/eb/d0e658465c029feb7083139d9ead51000742e88b1fb7f1504e19e1b4ce6e/scikit_learn-0.24.2-cp36-cp36m-manylinux2010_x86_64.whl (22.2MB)\n",
"\u001b[K |████████████████████████████████| 22.3MB 116kB/s \n",
"\u001b[?25hCollecting typing-extensions\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/74/60/18783336cc7fcdd95dae91d73477830aa53f5d3181ae4fe20491d7fc3199/typing_extensions-3.10.0.2-py3-none-any.whl\n",
"Collecting joblib\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/55/85/70c6602b078bd9e6f3da4f467047e906525c355a4dacd4f71b97a35d9897/joblib-1.0.1-py3-none-any.whl\n",
"Requirement already satisfied: click in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from sacremoses->transformers==3.0.2->electra-pytorch) (7.0)\n",
"Requirement already satisfied: six in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from sacremoses->transformers==3.0.2->electra-pytorch) (1.12.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from packaging->transformers==3.0.2->electra-pytorch) (2.4.0)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers==3.0.2->electra-pytorch) (2.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers==3.0.2->electra-pytorch) (2019.3.9)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers==3.0.2->electra-pytorch) (3.0.4)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers==3.0.2->electra-pytorch) (1.24.2)\n",
"Collecting threadpoolctl>=2.0.0\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/c6/e8/c216b9b60cbba4642d3ca1bae7a53daa0c24426f662e0e3ce3dc7f6caeaa/threadpoolctl-2.2.0-py3-none-any.whl\n",
"Installing collected packages: filelock, tqdm, joblib, sacremoses, tokenizers, sentencepiece, dataclasses, transformers, scipy, threadpoolctl, scikit-learn, sklearn, typing-extensions, torch, electra-pytorch\n",
"\u001b[31mERROR: Could not install packages due to an EnvironmentError: [Errno 13] Permission denied: '/usr/local/lib/python3.6/dist-packages/filelock'\n",
"Consider using the `--user` option or check the permissions.\n",
"\u001b[0m\n",
"\u001b[33mWARNING: You are using pip version 19.3.1; however, version 21.2.4 is available.\n",
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n",
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting transformers\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/9b/9f/ec840879f2604d5c948858f605c55256f21010bca3c5705a344daafafbba/transformers-4.10.3-py3-none-any.whl\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from transformers) (1.3.0)\n",
"Collecting huggingface-hub>=0.0.12\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/ab/45/2d908576740ee8876438cfa2a57aefefeb9f677a3df826d6b4d3c7e6b3cc/huggingface_hub-0.0.17-py3-none-any.whl\n",
"Collecting tqdm>=4.27\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/63/f3/b7a1b8e40fd1bd049a34566eb353527bb9b8e9b98f8b6cf803bb64d8ce95/tqdm-4.62.3-py2.py3-none-any.whl\n",
"Requirement already satisfied: pyyaml>=5.1 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from transformers) (5.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.6.8)\n",
"Requirement already satisfied: requests in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from transformers) (2.21.0)\n",
"Collecting tokenizers<0.11,>=0.10.1\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/bf/20/3605db440db4f96d5ffd66b231a043ae451ec7e5e4d1a2fb6f20608006c4/tokenizers-0.10.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl\n",
"Collecting dataclasses; python_version < \"3.7\"\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/fe/ca/75fac5856ab5cfa51bbbcefa250182e50441074fdc3f803f6e76451fab43/dataclasses-0.8-py3-none-any.whl\n",
"Requirement already satisfied: packaging in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from transformers) (19.0)\n",
"Collecting sacremoses\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/36/bf/15f8df78bce5eee8223553123173f010d426565980e457c559a71ecbecc3/sacremoses-0.0.46-py3-none-any.whl\n",
"Collecting numpy>=1.17\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/14/32/d3fa649ad7ec0b82737b92fefd3c4dd376b0bb23730715124569f38f3a08/numpy-1.19.5-cp36-cp36m-manylinux2010_x86_64.whl\n",
"Collecting filelock\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/e0/a5/23349971aaf2bb56cf0bb084e51b4020098e53465c97eeb730e2e2a1da13/filelock-3.1.0-py2.py3-none-any.whl\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers) (0.6.0)\n",
"Collecting typing-extensions\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/74/60/18783336cc7fcdd95dae91d73477830aa53f5d3181ae4fe20491d7fc3199/typing_extensions-3.10.0.2-py3-none-any.whl\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers) (2.8)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers) (2019.3.9)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from requests->transformers) (1.24.2)\n",
"Requirement already satisfied: six in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from packaging->transformers) (1.12.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from packaging->transformers) (2.4.0)\n",
"Requirement already satisfied: click in /mnt/data/terry/terry/.local/lib/python3.6/site-packages (from sacremoses->transformers) (7.0)\n",
"Collecting joblib\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/55/85/70c6602b078bd9e6f3da4f467047e906525c355a4dacd4f71b97a35d9897/joblib-1.0.1-py3-none-any.whl\n",
"Requirement already satisfied: more-itertools in /usr/local/lib/python3.6/dist-packages (from zipp>=0.5->importlib-metadata; python_version < \"3.8\"->transformers) (8.0.2)\n",
"\u001b[31mERROR: huggingface-hub 0.0.17 has requirement packaging>=20.9, but you'll have packaging 19.0 which is incompatible.\u001b[0m\n",
"Installing collected packages: tqdm, typing-extensions, filelock, huggingface-hub, tokenizers, dataclasses, joblib, sacremoses, numpy, transformers\n",
"\u001b[31mERROR: Could not install packages due to an EnvironmentError: [Errno 13] Permission denied: '/usr/local/lib/python3.6/dist-packages/tqdm-4.62.3.dist-info'\n",
"Consider using the `--user` option or check the permissions.\n",
"\u001b[0m\n",
"\u001b[33mWARNING: You are using pip version 19.3.1; however, version 21.2.4 is available.\n",
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n",
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting reformer_pytorch\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/6f/03/ed8ec108b14ef653d3ff9abbe9512d7858e61be286a270e684ae0fd211c0/reformer_pytorch-1.4.3-py3-none-any.whl\n",
"Collecting product-key-memory\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/31/3b/c1f8977e4b04f047acc7b23c7424d1e2e624ed7031e699a2ac2287af4c1f/product_key_memory-0.1.10.tar.gz\n",
"Collecting einops\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/1e/00/919421f097de2a6ca2d9b4d9f3f596274e44c243a6ecca210cd0811032c0/einops-0.3.2-py3-none-any.whl\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from reformer_pytorch) (1.1.0)\n",
"Collecting axial-positional-embedding>=0.1.0\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/7a/27/ad886f872b15153905d957a70670efe7521a07c70d324ff224f998e52492/axial_positional_embedding-0.2.1.tar.gz\n",
"Collecting local-attention\n",
" Using cached https://pypi.tuna.tsinghua.edu.cn/packages/c3/64/a555f10aa7258703235ac494448867b9be4048a452fe829840431bb64156/local_attention-1.4.3-py3-none-any.whl\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (1.16.3)\n",
"Installing collected packages: product-key-memory, einops, axial-positional-embedding, local-attention, reformer-pytorch\n",
" Running setup.py install for product-key-memory ... \u001b[?25lerror\n",
"\u001b[31m ERROR: Command errored out with exit status 1:\n",
" command: /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '\"'\"'/tmp/pip-install-k8xz2h_3/product-key-memory/setup.py'\"'\"'; __file__='\"'\"'/tmp/pip-install-k8xz2h_3/product-key-memory/setup.py'\"'\"';f=getattr(tokenize, '\"'\"'open'\"'\"', open)(__file__);code=f.read().replace('\"'\"'\\r\\n'\"'\"', '\"'\"'\\n'\"'\"');f.close();exec(compile(code, __file__, '\"'\"'exec'\"'\"'))' install --record /tmp/pip-record-krtz1ol7/install-record.txt --single-version-externally-managed --compile\n",
" cwd: /tmp/pip-install-k8xz2h_3/product-key-memory/\n",
" Complete output (11 lines):\n",
" running install\n",
" running build\n",
" running build_py\n",
" creating build\n",
" creating build/lib\n",
" creating build/lib/product_key_memory\n",
" copying product_key_memory/__init__.py -> build/lib/product_key_memory\n",
" copying product_key_memory/product_key_memory.py -> build/lib/product_key_memory\n",
" running install_lib\n",
" creating /usr/local/lib/python3.6/dist-packages/product_key_memory\n",
" error: could not create '/usr/local/lib/python3.6/dist-packages/product_key_memory': Permission denied\n",
" ----------------------------------------\u001b[0m\n",
"\u001b[31mERROR: Command errored out with exit status 1: /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '\"'\"'/tmp/pip-install-k8xz2h_3/product-key-memory/setup.py'\"'\"'; __file__='\"'\"'/tmp/pip-install-k8xz2h_3/product-key-memory/setup.py'\"'\"';f=getattr(tokenize, '\"'\"'open'\"'\"', open)(__file__);code=f.read().replace('\"'\"'\\r\\n'\"'\"', '\"'\"'\\n'\"'\"');f.close();exec(compile(code, __file__, '\"'\"'exec'\"'\"'))' install --record /tmp/pip-record-krtz1ol7/install-record.txt --single-version-externally-managed --compile Check the logs for full command output.\u001b[0m\n",
"\u001b[33mWARNING: You are using pip version 19.3.1; however, version 21.2.4 is available.\n",
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n",
"\u001b[?25h"
]
}
],
"metadata": {
"id": "wQ-4__g5cuZc"
}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import torch\n",
"from torch import nn\n",
"from reformer_pytorch import ReformerLM\n",
"\n",
"from electra_pytorch import Electra\n",
"\n",
"# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator\n",
"\n",
"generator = ReformerLM(\n",
" num_tokens = 500,\n",
" emb_dim = 128,\n",
" dim = 128, # smaller hidden dimension\n",
" heads = 4, # fewer heads\n",
" ff_mult = 2, # smaller feed forward intermediate dimension\n",
" dim_head = 64,\n",
" depth = 12,\n",
" max_seq_len = 1024\n",
")\n",
"\n",
"discriminator = ReformerLM(\n",
" num_tokens = 500,\n",
" emb_dim = 128,\n",
" dim = 256,\n",
" dim_head = 64,\n",
" heads = 16,\n",
" depth = 12,\n",
" ff_mult = 4,\n",
" max_seq_len = 1024\n",
")\n",
"\n",
"# (2) weight tie the token and positional embeddings of generator and discriminator\n",
"\n",
"generator.token_emb = discriminator.token_emb\n",
"generator.pos_emb = discriminator.pos_emb\n",
"# weight tie any other embeddings if available, token type embeddings, etc.\n",
"\n",
"# (3) instantiate electra\n",
"\n",
"trainer = Electra(\n",
" generator,\n",
" discriminator,\n",
" discr_dim = 256, # hidden dimension (dim) of the discriminator, i.e. the width of the 'reformer' layer output\n",
" discr_layer = 'reformer', # name of the discriminator layer whose output is used to predict whether each token is original or replaced\n",
" mask_token_id = 2, # the token id reserved for masking\n",
" pad_token_id = 0, # the token id for padding\n",
" mask_prob = 0.15, # masking probability for masked language modeling\n",
" mask_ignore_token_ids = [] # ids of tokens that are never masked, e.g. [CLS], [SEP]\n",
")\n",
"\n",
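"# (4) train on a dummy batch: random token ids stand in for real (e.g. gene-sequence) data.\n",
"# Ids start at 3 so the random batch avoids the pad (0) and mask (2) ids.\n",
"# The sequence length (128) must be divisible by 2 * bucket_size (Reformer's default bucket_size is 64).\n",
"data = torch.randint(3, 500, (1, 128))\n",
"\n",
"# Adam with lr = 3e-4 is an illustrative choice, not a tuned or prescribed setting.\n",
"optimizer = torch.optim.Adam(trainer.parameters(), lr = 3e-4)\n",
"\n",
"for step in range(10):\n",
"    optimizer.zero_grad()\n",
"    results = trainer(data)\n",
"    results.loss.backward()\n",
"    optimizer.step()  # without this call the weights are never updated\n",
"    print(step, results.loss.item())\n",
"\n",
"# after much training, the discriminator should have improved\n",
"torch.save(discriminator, './pretrained-model.pt')"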
],
"outputs": [
{
"output_type": "error",
"ename": "ModuleNotFoundError",
"evalue": "No module named 'reformer_pytorch'",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-2-f321070409e0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mreformer_pytorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mReformerLM\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0melectra_pytorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mElectra\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'reformer_pytorch'"
]
}
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wPhBKjaCc7uQ",
"outputId": "9c744875-78b7-452e-ddd2-a8e81b755c93"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"results.loss"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(44.5330, grad_fn=<AddBackward0>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sIuAZ2DOd1Ln",
"outputId": "5e44b77e-8e09-4b1d-8368-71927a0be1a2"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"data"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[218, 472, 206, ..., 80, 334, 132]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GatBTzPIdynE",
"outputId": "4181348d-d336-4007-9f6e-02ab555e4836"
}
}
]
}
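Step (1) in the notebook asks for a generator roughly a quarter to a half the size of the discriminator. The back-of-the-envelope count below compares the two `ReformerLM` configurations used above. It is a rough sketch that assumes a standard transformer block (separate Q, K, V and output projections plus a two-layer feed-forward network, biases and layer norms ignored); Reformer's real layers differ in detail (its LSH attention, for example, shares the query and key projection), so the numbers are only an approximation. `approx_block_params` is a hypothetical helper, not part of any library.

```python
def approx_block_params(dim, heads, dim_head, ff_mult):
    """Rough parameter count of one standard transformer block (biases and norms ignored)."""
    inner = heads * dim_head                   # total width of all attention heads
    attention = 4 * dim * inner                # Q, K, V (dim -> inner) plus output (inner -> dim)
    feed_forward = 2 * dim * (dim * ff_mult)   # dim -> dim * ff_mult -> dim
    return attention + feed_forward


# Hyperparameters copied from the notebook's generator and discriminator (depth 12 for both).
DEPTH = 12
gen = DEPTH * approx_block_params(dim=128, heads=4, dim_head=64, ff_mult=2)
disc = DEPTH * approx_block_params(dim=256, heads=16, dim_head=64, ff_mult=4)

print(f"generator ~{gen:,} params, discriminator ~{disc:,} params")
print(f"parameter ratio ~{gen / disc:.3f}, width ratio {128 / 256:.2f}")
```

By hidden width the generator is half the discriminator, but by this rough count its blocks hold only about an eighth of the parameters, so whether the configuration meets the "quarter to a half" guideline depends on which notion of size is meant.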
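The `Electra` wrapper in step (3) trains the discriminator on replaced-token detection. The pure-Python sketch below illustrates the idea in simplified form and is not electra-pytorch's implementation: it masks about `mask_prob` of the eligible positions (never ids such as the pad id), writes `mask_id` into those positions for the generator, lets a stand-in generator propose a token for each masked position, and labels a position 1 only if the token the discriminator sees differs from the original. As in the ELECTRA paper, a generator sample that happens to equal the original token counts as original (label 0). `electra_corrupt` and `sample_fn` are hypothetical names chosen for this sketch.

```python
import random


def electra_corrupt(tokens, sample_fn, mask_prob=0.15, mask_id=2, ignore_ids=(0,), rng=None):
    """Simplified replaced-token-detection corruption (hypothetical helper).

    Returns the generator input (masked positions set to mask_id), the
    discriminator input (masked positions filled by sample_fn), and per-position
    labels (1 = replaced, 0 = original).
    """
    rng = rng or random.Random(0)
    # choose positions to mask, never masking ignored ids such as padding
    masked = [t not in ignore_ids and rng.random() < mask_prob for t in tokens]
    gen_input = [mask_id if m else t for t, m in zip(tokens, masked)]
    # the stand-in generator sees the masked sequence and proposes a token per masked position
    disc_input = [sample_fn(gen_input, i) if m else t for i, (t, m) in enumerate(zip(tokens, masked))]
    # a sample equal to the original counts as original, so compare token by token
    labels = [int(d != t) for d, t in zip(disc_input, tokens)]
    return gen_input, disc_input, labels
```

In the notebook, `mask_prob = 0.15`, `mask_token_id = 2` and `pad_token_id = 0` play the roles of `mask_prob`, `mask_id` and `ignore_ids` here.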
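`results.loss` in the later cells is a single scalar combining the generator's masked-language-modeling loss with the discriminator's binary cross-entropy. The sketch below assumes the weighted sum described in the ELECTRA paper, which weights the discriminator term by λ = 50. The helper names and the `gen_weight`/`disc_weight` parameters are hypothetical, and whether an installed electra-pytorch version uses the same weighting is worth checking in its source.

```python
import math


def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy for one raw logit and a 0/1 label."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def electra_total_loss(mlm_loss, disc_logits, disc_labels, gen_weight=1.0, disc_weight=50.0):
    """gen_weight * L_MLM + disc_weight * L_disc, with L_disc averaged over positions."""
    disc_loss = sum(bce_with_logits(z, y) for z, y in zip(disc_logits, disc_labels)) / len(disc_logits)
    return gen_weight * mlm_loss + disc_weight * disc_loss


# A discriminator that outputs logit 0 everywhere pays log(2) per position,
# so the weighted discriminator term alone contributes 50 * log(2), about 34.66.
print(electra_total_loss(mlm_loss=2.0, disc_logits=[0.0] * 8, disc_labels=[0, 1] * 4))
```

The large discriminator weight means the summed loss can sit well above a single cross-entropy term even when each component is modest.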
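The notebook's title, electra-pytorch基因序列示例, means "electra-pytorch gene sequence example", but the training cell feeds random ids. One hypothetical way to turn DNA into ids that fit the configuration above (`num_tokens = 500`, `pad_token_id = 0`, `mask_token_id = 2`, `max_seq_len = 1024`) is to split the sequence into non-overlapping 4-mers. The sketch assumes a vocabulary of the 256 possible 4-mers over A, C, G, T mapped to ids 3 to 258, id 1 for any chunk containing other characters (such as N), and right-padding with 0 to a multiple of 128, because Reformer's LSH attention needs the sequence length to be divisible by twice its bucket size (64 by default). `encode_dna`, `KMER_TO_ID` and `UNK_ID` are hypothetical names; none of these choices come from the original notebook.

```python
from itertools import product

PAD_ID, UNK_ID, MASK_ID = 0, 1, 2  # 0 and 2 match the notebook; UNK_ID = 1 is a hypothetical choice
K = 4
# Hypothetical vocabulary: every 4-mer over A, C, G, T gets an id from 3 to 258.
KMER_TO_ID = {"".join(p): i + 3 for i, p in enumerate(product("ACGT", repeat=K))}


def encode_dna(seq, multiple=128, max_len=1024):
    """Map a DNA string to non-overlapping 4-mer ids, padded with PAD_ID to a multiple of `multiple`."""
    seq = seq.upper()
    # a trailing chunk shorter than K is dropped
    ids = [KMER_TO_ID.get(seq[i:i + K], UNK_ID) for i in range(0, len(seq) - K + 1, K)]
    ids = ids[:max_len]  # max_len (1024) is itself a multiple of 128
    padded_len = max(multiple, -(-len(ids) // multiple) * multiple)  # round up to the next multiple
    return ids + [PAD_ID] * (padded_len - len(ids))


print(encode_dna("ACGTACGTNNNNGATTACA")[:6])  # [30, 30, 1, 146, 0, 0]
```

With PyTorch available, a batch would then be `torch.tensor([encode_dna(seq)])` in place of the random `data` tensor.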