Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save mauicv/75b8a40edafc96c0f9d6a84b16f3c708 to your computer and use it in GitHub Desktop.
Save mauicv/75b8a40edafc96c0f9d6a84b16f3c708 to your computer and use it in GitHub Desktop.
Transformer-Token-Choice-MoE-shakesphere-char.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "A100",
"machine_shape": "hm",
"authorship_tag": "ABX9TyMtnzK/yele0ZtwFKR5kJWj",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/mauicv/75b8a40edafc96c0f9d6a84b16f3c708/transformer-token-choice-moe-shakesphere-char.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "26UTHYvigRm_",
"outputId": "4020e4cb-c03d-4ea0-dce9-822f1ab0440b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"# Mount Google Drive so checkpoints and training history outlive the Colab runtime.\n",
"from google.colab import drive\n",
"\n",
"DRIVE_MOUNT_POINT = '/content/drive'\n",
"drive.mount(DRIVE_MOUNT_POINT)"
]
},
{
"cell_type": "code",
"source": [
"# Install pytfex (the model library used below) from the author's feature branch.\n",
"# %pip (not !pip) installs into the environment of the running kernel.\n",
"# NOTE(review): a branch ref is not reproducible - pin to a commit (`@<sha>`) once final.\n",
"%pip install -q git+https://github.com/mauicv/transformers@feature/token-choice-moe"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ULd-3x193Olb",
"outputId": "eb94442a-38c8-42f2-f2e1-f44e58059b74"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for pytfex (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Only livelossplot (live training curves) is used in this notebook. The data is\n",
"# character-level, so no tokenizer library (tokenizers / tiktoken) is needed.\n",
"%pip install -q livelossplot"
],
"metadata": {
"id": "QQB3fgwIgdQW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9bbcd3b4-87fd-4d5f-d7eb-8f4457635d4a"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"llmx 0.0.15a0 requires cohere, which is not installed.\n",
"llmx 0.0.15a0 requires openai, which is not installed.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
]
},
{
"cell_type": "code",
"source": [
"# The following is adapted from https://github.com/karpathy/nanoGPT/tree/master/data/shakespeare_char\n",
"\n",
"\"\"\"\n",
"Prepare the Shakespeare dataset for character-level language modeling.\n",
"So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.\n",
"Will save train.bin, val.bin containing the ids, and meta.pkl containing the\n",
"encoder and decoder and some other related info.\n",
"\"\"\"\n",
"import os\n",
"import pickle\n",
"import requests\n",
"import numpy as np\n",
"\n",
"# download the tiny shakespeare dataset (cached in ./data after the first run)\n",
"data_dir = './data'\n",
"os.makedirs(data_dir, exist_ok=True)\n",
"input_file_path = os.path.join(data_dir, 'input.txt')\n",
"if not os.path.exists(input_file_path):\n",
"    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'\n",
"    response = requests.get(data_url, timeout=30)\n",
"    # Fail loudly on HTTP errors; otherwise an error page would be written to\n",
"    # input.txt and silently reused as the dataset on every later run.\n",
"    response.raise_for_status()\n",
"    with open(input_file_path, 'w', encoding='utf-8') as f:\n",
"        f.write(response.text)\n",
"\n",
"with open(input_file_path, 'r', encoding='utf-8') as f:\n",
"    data = f.read()\n",
"print(f\"length of dataset in characters: {len(data):,}\")\n",
"\n",
"# get all the unique characters that occur in this text\n",
"chars = sorted(list(set(data)))\n",
"vocab_size = len(chars)\n",
"print(\"all the unique characters:\", ''.join(chars))\n",
"print(f\"vocab size: {vocab_size:,}\")\n",
"# ids are stored as uint16 below, which would silently wrap for a larger vocabulary\n",
"assert vocab_size <= np.iinfo(np.uint16).max + 1\n",
"\n",
"# create a mapping from characters to integers\n",
"stoi = { ch:i for i,ch in enumerate(chars) }\n",
"itos = { i:ch for i,ch in enumerate(chars) }\n",
"def encode(s):\n",
"    return [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
"def decode(l):\n",
"    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
"\n",
"# create the train and test splits (first 90% train, last 10% val)\n",
"n = len(data)\n",
"train_data = data[:int(n*0.9)]\n",
"val_data = data[int(n*0.9):]\n",
"\n",
"# encode both to integers\n",
"train_ids = encode(train_data)\n",
"val_ids = encode(val_data)\n",
"print(f\"train has {len(train_ids):,} tokens\")\n",
"print(f\"val has {len(val_ids):,} tokens\")\n",
"\n",
"# export to bin files\n",
"train_ids = np.array(train_ids, dtype=np.uint16)\n",
"val_ids = np.array(val_ids, dtype=np.uint16)\n",
"train_ids.tofile(os.path.join(data_dir, 'train.bin'))\n",
"val_ids.tofile(os.path.join(data_dir, 'val.bin'))\n",
"\n",
"# save the meta information as well, to help us encode/decode later\n",
"meta = {\n",
"    'vocab_size': vocab_size,\n",
"    'itos': itos,\n",
"    'stoi': stoi,\n",
"}\n",
"with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:\n",
"    pickle.dump(meta, f)\n",
"\n",
"# length of dataset in characters: 1115394\n",
"# all the unique characters:\n",
"# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
"# vocab size: 65\n",
"# train has 1003854 tokens\n",
"# val has 111540 tokens"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gwue7RgXNarJ",
"outputId": "1499ce02-9e7f-493b-da2d-325ddbe59b4f"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"length of dataset in characters: 1,115,394\n",
"all the unique characters: \n",
" !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
"vocab size: 65\n",
"train has 1,003,854 tokens\n",
"val has 111,540 tokens\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"\n",
"import numpy as np\n",
"import os\n",
"import torch\n",
"import yaml\n",
"\n",
"from pytfex.transformer.mask import get_causal_mask\n",
"from torch.optim.lr_scheduler import ExponentialLR\n",
"\n",
"data_dir = 'data'\n",
"block_size = 256  # context length in characters\n",
"batch_size = 64+32\n",
"device = 'cuda'\n",
"# device = 'cpu'\n",
"\n",
"# uint16 token ids written by the data-preparation cell, memory-mapped read-only\n",
"train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')\n",
"val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')\n",
"\n",
"def get_batch(split):\n",
"    \"\"\"Sample `batch_size` random windows from the 'train' split (any other value: val).\n",
"\n",
"    Returns (x, y): int64 tensors of shape (batch_size, block_size) on `device`,\n",
"    where y is x shifted by one position (the next-character targets).\n",
"    \"\"\"\n",
"    data = train_data if split == 'train' else val_data\n",
"    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
"    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])\n",
"    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])\n",
"    x, y = x.to(device), y.to(device)\n",
"    return x, y\n",
"\n",
"@torch.no_grad()\n",
"def validate(model, num_batches=3):\n",
"    \"\"\"Next-character accuracy of `model` over `num_batches` random val batches.\n",
"\n",
"    Runs without autograd and in eval mode (dropout off), then restores the\n",
"    model's previous train/eval mode so it is safe to call mid-training.\n",
"    Returns a 0-dim tensor on `device`. Raises ValueError if num_batches < 1.\n",
"    \"\"\"\n",
"    if num_batches < 1:\n",
"        raise ValueError('num_batches must be >= 1')\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        total = 0\n",
"        sum_acc = 0\n",
"        mask = get_causal_mask(block_size).to(device)\n",
"        for _ in range(num_batches):\n",
"            x, y_true = get_batch('val')\n",
"            y_pred = model(x, mask=mask)\n",
"            r = torch.eq(y_true, y_pred.argmax(dim=-1))\n",
"            b, l = r.shape\n",
"            total += b*l\n",
"            sum_acc += r.sum()\n",
"        return sum_acc / total\n",
"    finally:\n",
"        model.train(was_training)\n",
"\n"
],
"metadata": {
"id": "qW7GwKWphn0c"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from pytfex.models import get_model, GPTTokenChoiceMoEConfig\n",
"\n",
"# 6-layer, 16-head, 1024-wide GPT whose MLPs are token-choice MoE layers with 8 experts.\n",
"# NOTE(review): k=2 is presumably the number of experts each token is routed to.\n",
"# Vocab and context size are not passed; the printed model shows Embedding(65, ...) and\n",
"# Embedding(256, ...), matching vocab_size and block_size - presumably library defaults.\n",
"config = GPTTokenChoiceMoEConfig(\n",
"    num_layers=6,\n",
"    num_heads=16,\n",
"    hdn_dim=1024,\n",
"    mlp_hdn_dim=1024,\n",
"    num_experts=8,\n",
"    k=2,\n",
"    dropout=0.01,\n",
"    batch_size=batch_size,\n",
")\n",
"\n",
"model = get_model(config)\n",
"model.to(device)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xuNofb90h_8l",
"outputId": "7e2238f7-c276-4f70-bc6a-97d433b50147"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GPT(\n",
" (drop): Dropout(p=0.01, inplace=False)\n",
" (embedder): MultiEmbedder(\n",
" (embedders): ModuleList(\n",
" (0): TokenEmbedder(\n",
" (tok_emb): Embedding(65, 1024)\n",
" )\n",
" (1): PositionEmbedder(\n",
" (pos_emb): Embedding(256, 1024)\n",
" )\n",
" )\n",
" )\n",
" (layers): ModuleList(\n",
" (0-5): 6 x TransformerLayer(\n",
" (attn): Attention(\n",
" (attn_dropout): Dropout(p=0.009999999776482582, inplace=False)\n",
" (resid_dropout): Dropout(p=0.009999999776482582, inplace=False)\n",
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
" (linear): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" (mlp): TokenChoiceMoE(\n",
" (experts): ModuleList(\n",
" (0-7): 8 x MLP(\n",
" (mlp_dropout): Dropout(p=0.009999999776482582, inplace=False)\n",
" (linear1): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (linear2): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" )\n",
" (gate): Linear(in_features=1024, out_features=8, bias=False)\n",
" )\n",
" (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (head): ClassificationHead(\n",
" (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (linear): Linear(in_features=1024, out_features=65, bias=False)\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"from pytfex.utils import count_parameters\n",
"\n",
"# Total trainable size; the recorded value (126,423,040) counts every expert in every layer.\n",
"n_params = count_parameters(model)\n",
"n_params"
],
"metadata": {
"id": "0iOwty520VIw",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "de8e0167-90df-4ad0-c8d7-87a3058cfba8"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"126423040"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"# Reference parameter count. 75,974,656 is exactly what the same 6-layer, 1024-wide GPT\n",
"# (same bias / LayerNorm layout as printed above) has with one dense MLP of hidden\n",
"# size 4096 per layer instead of 8 MoE experts of width 1024.\n",
"# NOTE(review): presumably recorded from a dense baseline run for comparison - confirm.\n",
"DENSE_REFERENCE_PARAM_COUNT = 75_974_656\n",
"DENSE_REFERENCE_PARAM_COUNT"
],
"metadata": {
"id": "KUK5nG0XGqs0",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "383438a3-e142-485a-9132-bddd405da1c8"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"75974656"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"import torch.nn as nn\n",
"\n",
"def _init_weights(module):\n",
"    \"\"\"Re-initialise one module in place (model.apply visits every submodule).\n",
"\n",
"    LayerNorm -> weight 1, bias 0. Linear / Embedding -> weight ~ N(0, 0.02),\n",
"    and Linear biases (when present) are zeroed.\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        if isinstance(module, nn.LayerNorm):\n",
"            module.bias.zero_()\n",
"            module.weight.fill_(1.0)\n",
"        elif isinstance(module, (nn.Linear, nn.Embedding)):\n",
"            module.weight.normal_(mean=0.0, std=0.02)\n",
"            has_bias = isinstance(module, nn.Linear) and module.bias is not None\n",
"            if has_bias:\n",
"                module.bias.zero_()\n",
"\n",
"model.apply(_init_weights)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1jtKtIisTbH7",
"outputId": "4af1e72b-6ba1-461a-f27f-3e59ea5e063d"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GPT(\n",
" (drop): Dropout(p=0.01, inplace=False)\n",
" (embedder): MultiEmbedder(\n",
" (embedders): ModuleList(\n",
" (0): TokenEmbedder(\n",
" (tok_emb): Embedding(65, 1024)\n",
" )\n",
" (1): PositionEmbedder(\n",
" (pos_emb): Embedding(256, 1024)\n",
" )\n",
" )\n",
" )\n",
" (layers): ModuleList(\n",
" (0-5): 6 x TransformerLayer(\n",
" (attn): Attention(\n",
" (attn_dropout): Dropout(p=0.009999999776482582, inplace=False)\n",
" (resid_dropout): Dropout(p=0.009999999776482582, inplace=False)\n",
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
" (linear): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" (mlp): TokenChoiceMoE(\n",
" (experts): ModuleList(\n",
" (0-7): 8 x MLP(\n",
" (mlp_dropout): Dropout(p=0.009999999776482582, inplace=False)\n",
" (linear1): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (linear2): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" )\n",
" (gate): Linear(in_features=1024, out_features=8, bias=False)\n",
" )\n",
" (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (head): ClassificationHead(\n",
" (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (linear): Linear(in_features=1024, out_features=65, bias=False)\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"learning_rate = 1e-3 # with baby networks can afford to go a bit higher\n",
"max_iters = 2500\n",
"lr_decay_iters = 2500 # make equal to max_iters usually\n",
"min_lr = 1e-4 # learning_rate / 10 usually\n",
"beta2 = 0.99 # make a bit bigger because number of tokens per iter is small\n",
"\n",
"warmup_iters = 100 # not super necessary potentially\n",
"# learning rate decay scheduler (cosine with warmup)\n",
"def get_lr(it):\n",
"    \"\"\"Learning rate for 0-based iteration `it`.\n",
"\n",
"    Linear warmup towards `learning_rate` over `warmup_iters` steps, cosine decay\n",
"    to `min_lr` at `lr_decay_iters`, then constant `min_lr`.\n",
"    \"\"\"\n",
"    # 1) linear warmup; (it + 1) so iteration 0 does not get lr = 0, which would\n",
"    #    make the first optimizer step a no-op\n",
"    if it < warmup_iters:\n",
"        return learning_rate * (it + 1) / (warmup_iters + 1)\n",
"    # 2) if it > lr_decay_iters, return min learning rate\n",
"    if it > lr_decay_iters:\n",
"        return min_lr\n",
"    # 3) in between, use cosine decay down to min learning rate\n",
"    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n",
"    assert 0 <= decay_ratio <= 1\n",
"    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1\n",
"    return min_lr + coeff * (learning_rate - min_lr)"
],
"metadata": {
"id": "Nic2IGjJBOOM"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from torch.optim import AdamW\n",
"\n",
"# Hyper-parameters come from the schedule cell: `learning_rate` is only the initial lr\n",
"# (the training loop overwrites it every step with get_lr) and `beta2` replaces\n",
"# AdamW's default second-moment decay of 0.999.\n",
"# NOTE(review): model.get_parameters(weight_decay=0.1) presumably returns decay /\n",
"# no-decay parameter groups - confirm in pytfex.\n",
"opt = AdamW(model.get_parameters(weight_decay=0.1), lr=learning_rate, betas=(0.9, beta2))"
],
"metadata": {
"id": "3yiPL7ECiI7U"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from livelossplot import PlotLosses\n",
"\n",
"# One live-updating panel per metric; the names must match the keys passed to\n",
"# plotlosses.update() in the training loop.\n",
"metric_groups = {name: [name] for name in ('loss', 'val_acc', 'lr')}\n",
"plotlosses = PlotLosses(groups=metric_groups)"
],
"metadata": {
"id": "0WAda9maHLt8"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from torch.nn import functional as F\n",
"import time\n",
"\n",
"mask = get_causal_mask(block_size).to(device)\n",
"eval_interval = 25  # iterations between validation / plot updates\n",
"grad_clip = 1.0     # max global gradient norm\n",
"\n",
"model.train()\n",
"history = []\n",
"for it in range(max_iters):\n",
"    ts = time.time()  # wall-clock time at the start of this iteration\n",
"    lr = get_lr(it)\n",
"    for param_group in opt.param_groups:\n",
"        param_group['lr'] = lr\n",
"\n",
"    opt.zero_grad()\n",
"    x, y_true = get_batch('train')\n",
"    logits = model(x, mask=mask)\n",
"    loss = F.cross_entropy(\n",
"        logits.view(-1, logits.size(-1)),\n",
"        y_true.view(-1), ignore_index=-1\n",
"    )\n",
"    loss.backward()\n",
"    # Clip only once gradients exist (after backward) and before the update;\n",
"    # clipping outside the loop, before any backward pass, has no effect.\n",
"    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
"    opt.step()\n",
"\n",
"    if it % eval_interval == 0:\n",
"        acc = validate(model)\n",
"        metrics = {\n",
"            'loss': loss.item(),\n",
"            'val_acc': acc.item(),\n",
"            'lr': lr\n",
"        }\n",
"        plotlosses.update(metrics)\n",
"        plotlosses.send()\n",
"        metrics['ts'] = ts\n",
"        history.append(metrics)\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 831
},
"id": "UKza08rUNVHw",
"outputId": "99fb77b3-234b-4626-fe67-9035278ffb43"
},
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x1200 with 4 Axes>"
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loss\n",
"\tloss \t (min: 0.078, max: 4.155, cur: 0.078)\n",
"lr\n",
"\tlr \t (min: 0.000, max: 0.001, cur: 0.000)\n",
"val_acc\n",
"\tval_acc \t (min: 0.025, max: 0.548, cur: 0.543)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Experiment artefacts live on the Google Drive mounted in the first cell.\n",
"root = os.path.join('.', 'drive', 'MyDrive', 'transformer-experiments')"
],
"metadata": {
"id": "CRKjd9o8romF"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# makedirs also creates `root` itself on a fresh Drive (os.mkdir would raise\n",
"# FileNotFoundError there) and is a no-op when the folder already exists.\n",
"run_dir = os.path.join(root, 'moe-tc')\n",
"os.makedirs(run_dir, exist_ok=True)\n",
"model.save_state(os.path.join(run_dir, 'model_state.pt'))"
],
"metadata": {
"id": "FOIvcB0tDG77"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import json\n",
"\n",
"# Persist the metrics logged by the training loop (one dict per evaluation).\n",
"history_path = os.path.join(root, 'moe-tc', 'history.json')\n",
"with open(history_path, 'w') as history_file:\n",
"    json.dump(history, history_file)"
],
"metadata": {
"id": "ARKA4CQEkWL2"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Reload the saved checkpoint (e.g. after a runtime restart).\n",
"root = './drive/MyDrive/transformer-experiments'\n",
"checkpoint_path = os.path.join(root, 'moe-tc', 'model_state.pt')\n",
"model.load_state(checkpoint_path)"
],
"metadata": {
"id": "PNp1s4UrY3I-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import pickle\n",
"\n",
"# Reload the vocabulary written by the data-preparation cell; meta['stoi'] and\n",
"# meta['itos'] are used by the sampling helper below. (pickle is acceptable only\n",
"# because this notebook produced the file itself.)\n",
"meta_path = os.path.join(data_dir, 'meta.pkl')\n",
"if os.path.exists(meta_path):\n",
"    with open(meta_path, 'rb') as f:\n",
"        meta = pickle.load(f)\n",
"    meta_vocab_size = meta['vocab_size']\n",
"    print(f\"found vocab_size = {meta_vocab_size} (inside {meta_path})\")\n",
"else:\n",
"    meta_vocab_size = None"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "so9W-VqQD4Ag",
"outputId": "ad23a51a-5884-4292-c8b1-247bea630dbb"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"found vocab_size = 65 (inside data/meta.pkl)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
"# NOTE: this shadows the character-level `decode(l)` from the data-preparation cell.\n",
"@torch.no_grad()\n",
"def decode(model, text, temp=1, limit=16, sample=True):\n",
"    \"\"\"Autoregressively extend `text` by `limit` characters and return the result.\n",
"\n",
"    Args:\n",
"        model: the trained character-level model.\n",
"        text: prompt; every character must be in meta['stoi'].\n",
"        temp: softmax temperature applied to the last-position logits.\n",
"        limit: number of characters to generate.\n",
"        sample: sample from the softmax if True, otherwise greedy argmax.\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    model.eval()  # disable dropout while generating\n",
"    try:\n",
"        input_ids = torch.tensor([meta['stoi'][char] for char in text], device=device)[None]\n",
"        result = text\n",
"        for _ in tqdm(range(limit)):\n",
"            # Condition on at most block_size tokens: the position embedding has\n",
"            # block_size (256) entries.\n",
"            context = input_ids[:, -block_size:]\n",
"            # Mask over the sequence length (dim 1); len(input_ids) is the batch size (1).\n",
"            mask = get_causal_mask(context.size(1)).to(device)\n",
"            preds = model(context, mask=mask)\n",
"            y = (preds[:, -1, :] / temp).softmax(dim=-1)\n",
"            if sample:\n",
"                next_token = torch.multinomial(y, 1)\n",
"            else:\n",
"                next_token = torch.argmax(y, dim=-1, keepdim=True)\n",
"            result += meta['itos'][next_token.item()]\n",
"            input_ids = torch.cat((input_ids, next_token), dim=-1)\n",
"        return result\n",
"    finally:\n",
"        model.train(was_training)\n"
],
"metadata": {
"id": "Vd97wNDzOdxw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Generate 200 characters continuing the prompt, sampling at temperature 0.7.\n",
"text = \"Who \"\n",
"generated = decode(model, text, temp=0.7, limit=200)\n",
"print(generated)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Apkc2hUWRZz_",
"outputId": "0f020368-90ff-460e-ab88-659a5b8797e2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 200/200 [00:13<00:00, 15.22it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Who know ?\n",
"\n",
"BRSH:\n",
"Whither nothink you were he that one but better need I was ere he, my liege, my liege.\n",
"\n",
"BUCKINGHAM:\n",
"You confess it not, the lorder I rest malice yet is for the rest for Engly the rest an\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Next-character accuracy on random validation batches (see `validate`).\n",
"final_val_acc = validate(model)\n",
"final_val_acc"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NbnQd_2_RbO7",
"outputId": "d33417ed-fe9f-421f-afa7-8bed375f82c1"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(0.5412, device='cuda:0')"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "aNkzn6D6Mm6D"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment