{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "3c50ef8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"^C\n",
"Requirement already satisfied: datasets in c:\\users\\maxim\\anaconda3\\lib\\site-packages (2.8.0)\n",
"Requirement already satisfied: pyarrow>=6.0.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (10.0.1)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.11.1)\n",
"Requirement already satisfied: pandas in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (1.4.4)\n",
"Requirement already satisfied: packaging in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (21.3)\n",
"Requirement already satisfied: responses<0.19 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.18.0)\n",
"Requirement already satisfied: multiprocess in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.70.14)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (2022.7.1)\n",
"Requirement already satisfied: dill<0.3.7 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.3.6)\n",
"Requirement already satisfied: aiohttp in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (3.8.3)\n",
"Requirement already satisfied: tqdm>=4.62.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (4.64.1)\n",
"Requirement already satisfied: requests>=2.19.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (2.28.1)\n",
"Requirement already satisfied: xxhash in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (3.2.0)\n",
"Requirement already satisfied: numpy>=1.17 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (1.21.5)\n",
"Requirement already satisfied: pyyaml>=5.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.3.0)\n",
"Requirement already satisfied: filelock in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.6.0)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from packaging->datasets) (3.0.9)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (2.0.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (1.26.11)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (2022.9.14)\n",
"Requirement already satisfied: colorama in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from tqdm>=4.62.1->datasets) (0.4.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.8.2)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.3.3)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (4.0.2)\n",
"Requirement already satisfied: attrs>=17.3.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (21.4.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from pandas->datasets) (2022.1)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n"
]
}
],
"source": [
"!pip install --pre torch torchtext --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117\n",
"!pip install datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40ac2394",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"\n",
"def truncate_string(string, length):\n",
" if len(string) > length:\n",
" return string[:length]\n",
" else:\n",
" return string\n",
"\n",
"\n",
"def remove_docstring_from_python_string(code):\n",
" # Use a regular expression to find all comments and docstrings\n",
" comments_and_docstrings = re.compile(r\"(\\\"\\\"\\\".*?\\\"\\\"\\\")|('''.*?''')\", re.DOTALL)\n",
"\n",
" # Remove the comments and docstrings from the code\n",
" cleaned_code = comments_and_docstrings.sub(\"\", code)\n",
"\n",
" return cleaned_code\n"
]
},
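{
"cell_type": "markdown",
"id": "f0a1b2c3",
"metadata": {},
"source": [
"A quick, illustrative sanity check for the helpers above: `remove_docstring_from_python_string` only strips triple-quoted docstrings (regular `#` comments are left untouched), and `truncate_string` simply cuts a string down to a maximum length. The sample function below is made up for demonstration purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0a1b2c4",
"metadata": {},
"outputs": [],
"source": [
"sample_code = 'def greet(name):\\n    \"\"\"Return a greeting.\"\"\"\\n    return \"Hello, \" + name\\n'\n",
"\n",
"# The docstring should disappear while the rest of the function body is kept.\n",
"print(remove_docstring_from_python_string(sample_code))\n",
"\n",
"# Truncation keeps only the first `length` characters.\n",
"print(truncate_string('abcdefgh', 4))  # expected: 'abcd'"
]
},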
{
"cell_type": "code",
"execution_count": null,
"id": "0244a970",
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"from datasets import load_dataset\n",
"\n",
"\n",
"def save_list_of_tuples_to_csv(filename, list_of_tuples, x_header='x', y_header='y'):\n",
" with open(filename, 'w', newline='', encoding='utf-8') as csvfile:\n",
" csv.QUOTE_ALL = True\n",
" writer = csv.writer(csvfile, delimiter=',', quotechar='|', escapechar='%')\n",
" writer.writerow([x_header, y_header])\n",
" for x, y in list_of_tuples:\n",
" writer.writerow([x, y])\n",
"\n",
"\n",
"def load_list_of_tuples_from_csv(filename):\n",
" dataset = []\n",
" with open(filename, newline='', encoding='utf-8') as csvfile:\n",
" csv.QUOTE_ALL = True\n",
" reader = csv.reader(csvfile, delimiter=',', quotechar='|', escapechar='%')\n",
" next(reader) # Skip the header row\n",
" for row in reader:\n",
" if len(row) > 1:\n",
" x = row[0]\n",
" y = row[1]\n",
" dataset.append((x, y))\n",
" return dataset\n",
"\n",
"\n",
"def generate_csv_dataset_from_huggingface(dataset, dataset_filter, x_header,\n",
" y_header, train_count, validation_count, test_count, train_csv_filename,\n",
" validation_csv_filename, test_csv_filename,\n",
" prepare_x_function=None, prepare_y_function=None):\n",
" if train_count > 0:\n",
" train_data = []\n",
" train_dataset_iter = load_dataset(dataset, dataset_filter, streaming=True, split=\"train\")\n",
" t = 0\n",
" for i in iter(train_dataset_iter):\n",
" if t > train_count:\n",
" break\n",
"\n",
" if prepare_x_function is not None:\n",
" x = prepare_x_function(i[x_header])\n",
" else:\n",
" x = i[x_header]\n",
"\n",
" if prepare_y_function is not None:\n",
" y = prepare_y_function(i[y_header])\n",
" else:\n",
" y = i[y_header]\n",
"\n",
" train_data.append((x, y))\n",
" t += 1\n",
"\n",
" save_list_of_tuples_to_csv(\n",
" train_csv_filename, train_data, x_header, y_header)\n",
"\n",
" if validation_count > 0:\n",
" validation_data = []\n",
" validation_dataset_iter = load_dataset(dataset, dataset_filter, streaming=True, split=\"validation\")\n",
" t = 0\n",
" for i in iter(validation_dataset_iter):\n",
" if t > validation_count:\n",
" break\n",
"\n",
" if prepare_x_function is not None:\n",
" x = prepare_x_function(i[x_header])\n",
" else:\n",
" x = i[x_header]\n",
"\n",
" if prepare_y_function is not None:\n",
" y = prepare_y_function(i[y_header])\n",
" else:\n",
" y = i[y_header]\n",
"\n",
" validation_data.append((x, y))\n",
" t += 1\n",
"\n",
" save_list_of_tuples_to_csv(\n",
" validation_csv_filename, validation_data, x_header,\n",
" y_header)\n",
"\n",
" if test_count > 0:\n",
" test_data = []\n",
" test_dataset_iter = load_dataset(dataset, dataset_filter, streaming=True, split=\"test\")\n",
" t = 0\n",
" for i in iter(test_dataset_iter):\n",
" if t > test_count:\n",
" break\n",
"\n",
" if prepare_x_function is not None:\n",
" x = prepare_x_function(i[x_header])\n",
" else:\n",
" x = i[x_header]\n",
"\n",
" if prepare_y_function is not None:\n",
" y = prepare_y_function(i[y_header])\n",
" else:\n",
" y = i[y_header]\n",
"\n",
" test_data.append((x, y))\n",
" t += 1\n",
"\n",
" save_list_of_tuples_to_csv(\n",
" test_csv_filename, test_data, x_header, y_header)\n",
"\n",
"\n",
"def load_csv_data_truncated(train_csv_filename, validation_csv_filename, test_csv_filename, max_sequence_length):\n",
" train_data = load_list_of_tuples_from_csv(train_csv_filename)\n",
" validation_data = load_list_of_tuples_from_csv(validation_csv_filename)\n",
" test_data = load_list_of_tuples_from_csv(test_csv_filename)\n",
" cleaned_train_data = []\n",
"\n",
" for desc, code in train_data:\n",
" description = truncate_string(desc, max_sequence_length)\n",
" function = truncate_string(code, max_sequence_length)\n",
" cleaned_train_data.append((description, function))\n",
"\n",
" cleaned_validation_data = []\n",
"\n",
" for desc, code in validation_data:\n",
" description = truncate_string(desc, max_sequence_length)\n",
" function = truncate_string(code, max_sequence_length)\n",
" cleaned_validation_data.append((description, function))\n",
"\n",
" cleaned_test_data = []\n",
"\n",
" for desc, code in test_data:\n",
" description = truncate_string(desc, max_sequence_length)\n",
" function = truncate_string(code, max_sequence_length)\n",
" cleaned_test_data.append((description, function))\n",
"\n",
" return cleaned_train_data, cleaned_validation_data, cleaned_test_data"
]
},
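{
"cell_type": "markdown",
"id": "e5f6a7b8",
"metadata": {},
"source": [
"The CSV helpers above use `|` as the quote character and `%` as the escape character so that commas, quotes and newlines inside descriptions and code survive the round trip. Below is a small, illustrative round-trip check; the file name `demo_pairs.csv` is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f6a7b9",
"metadata": {},
"outputs": [],
"source": [
"_demo_pairs = [('adds two numbers', 'def add(a, b):\\n    return a + b')]\n",
"save_list_of_tuples_to_csv('demo_pairs.csv', _demo_pairs, 'description', 'code')\n",
"# Expected: the same (description, code) pair comes back unchanged.\n",
"print(load_list_of_tuples_from_csv('demo_pairs.csv'))"
]
},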
{
"cell_type": "code",
"execution_count": null,
"id": "6762bce7",
"metadata": {},
"outputs": [],
"source": [
"from typing import Tuple, List\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class PyTorchDataset(Dataset):\n",
" def __init__(self, dataset: List[Tuple], **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.dataset = dataset\n",
"\n",
" def __len__(self):\n",
" return len(self.dataset)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.dataset[idx][0], self.dataset[idx][1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64c785fd",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import math\n",
"\n",
"import torch\n",
"from torch import nn as nn, Tensor\n",
"from torch.nn import Transformer\n",
"\n",
"\n",
"class PositionalEncoding(nn.Module):\n",
" def __init__(self,\n",
" emb_size: int,\n",
" dropout: float,\n",
" maxlen: int = 5000):\n",
" super(PositionalEncoding, self).__init__()\n",
" den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)\n",
" pos = torch.arange(0, maxlen).reshape(maxlen, 1)\n",
" pos_embedding = torch.zeros((maxlen, emb_size))\n",
" pos_embedding[:, 0::2] = torch.sin(pos * den)\n",
" pos_embedding[:, 1::2] = torch.cos(pos * den)\n",
" pos_embedding = pos_embedding.unsqueeze(-2)\n",
"\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('pos_embedding', pos_embedding)\n",
"\n",
" def forward(self, token_embedding: Tensor):\n",
" return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])\n",
"\n",
"\n",
"class TokenEmbedding(nn.Module):\n",
" def __init__(self, vocab_size: int, emb_size):\n",
" super(TokenEmbedding, self).__init__()\n",
" self.embedding = nn.Embedding(vocab_size, emb_size)\n",
" self.emb_size = emb_size\n",
"\n",
" def forward(self, tokens: Tensor):\n",
" return self.embedding(tokens.long()) * math.sqrt(self.emb_size)\n",
"\n",
"\n",
"class Seq2SeqTransformer(nn.Module):\n",
" def __init__(self,\n",
" num_encoder_layers: int,\n",
" num_decoder_layers: int,\n",
" d_model: int,\n",
" num_heads: int,\n",
" src_vocab_size: int,\n",
" tgt_vocab_size: int,\n",
" dim_feedforward: int,\n",
" dropout: float = 0.1):\n",
" super(Seq2SeqTransformer, self).__init__()\n",
" self.transformer = Transformer(d_model=d_model,\n",
" nhead=num_heads,\n",
" num_encoder_layers=num_encoder_layers,\n",
" num_decoder_layers=num_decoder_layers,\n",
" dim_feedforward=dim_feedforward,\n",
" dropout=dropout)\n",
" self.generator = nn.Linear(d_model, tgt_vocab_size)\n",
" self.src_tok_emb = TokenEmbedding(src_vocab_size, d_model)\n",
" self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, d_model)\n",
" self.positional_encoding = PositionalEncoding(\n",
" d_model, dropout=dropout)\n",
"\n",
" def forward(self,\n",
" src: Tensor,\n",
" trg: Tensor,\n",
" src_mask: Tensor,\n",
" tgt_mask: Tensor,\n",
" src_padding_mask: Tensor,\n",
" tgt_padding_mask: Tensor,\n",
" memory_key_padding_mask: Tensor):\n",
" src_emb = self.positional_encoding(self.src_tok_emb(src))\n",
" tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))\n",
" outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,\n",
" src_padding_mask, tgt_padding_mask, memory_key_padding_mask)\n",
" return self.generator(outs)\n",
"\n",
" def encode(self, src: Tensor, src_mask: Tensor):\n",
" return self.transformer.encoder(self.positional_encoding(\n",
" self.src_tok_emb(src)), src_mask)\n",
"\n",
" def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):\n",
" return self.transformer.decoder(self.positional_encoding(\n",
" self.tgt_tok_emb(tgt)), memory, tgt_mask)\n",
"\n"
]
},
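{
"cell_type": "markdown",
"id": "c7d8e9f0",
"metadata": {},
"source": [
"The `PositionalEncoding` module above implements the standard sinusoidal encoding from *Attention Is All You Need*:\n",
"\n",
"$$PE_{(pos,\\,2i)} = \\sin\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right), \\qquad PE_{(pos,\\,2i+1)} = \\cos\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right)$$\n",
"\n",
"The `den` tensor in the code is $10000^{-2i/d_{model}}$, computed as $\\exp(-2i\\,\\ln 10000 / d_{model})$ for numerical convenience, and `TokenEmbedding` scales embeddings by $\\sqrt{d_{model}}$ as in the original paper so that token embeddings and positional encodings have comparable magnitudes. Note that `nn.Transformer` defaults to `batch_first=False`, so every tensor in this notebook is laid out as `(sequence_length, batch_size, ...)`."
]
},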
{
"cell_type": "code",
"execution_count": null,
"id": "ef96ce13",
"metadata": {},
"outputs": [],
"source": [
"from torchtext.data import get_tokenizer\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"import pickle\n",
"from typing import Iterable, List\n",
"import tokenize\n",
"from io import BytesIO\n",
"\n",
"\n",
"def tokenize_code(code_string):\n",
" code_string = code_string + \"\\n\"\n",
" code_string_bytes = code_string.encode()\n",
" code = BytesIO(code_string_bytes)\n",
" tokens = []\n",
" i = 0\n",
" try:\n",
" for token_info in tokenize.tokenize(code.readline):\n",
" token_type = token_info[0]\n",
" token_string = token_info[1]\n",
" if i == 0:\n",
" if token_string == 'utf-8':\n",
" continue\n",
" if token_type == tokenize.NEWLINE:\n",
" tokens.append('NEWLINE')\n",
" elif token_type == tokenize.INDENT:\n",
" tokens.append('INDENT')\n",
" elif token_type == tokenize.DEDENT:\n",
" tokens.append('DEDENT')\n",
" else:\n",
" # This is a regular token\n",
" tokens.append(token_string)\n",
" i += 1\n",
" except tokenize.TokenError:\n",
" return tokens\n",
"\n",
" return tokens\n",
"\n",
"\n",
"def load_token_and_vocab_transform(token_transform_filename, vocab_transform_filename):\n",
" return pickle.load(open(token_transform_filename, \"rb\")), pickle.load(\n",
" open(vocab_transform_filename, \"rb\"))\n",
"\n",
"\n",
"def save_token_and_vocab_transform(token_transform, vocab_transform, token_transform_filename,\n",
" vocab_transform_filename):\n",
" pickle.dump(token_transform, open(token_transform_filename, \"wb\"))\n",
" pickle.dump(vocab_transform, open(vocab_transform_filename, \"wb\"))\n",
"\n",
"\n",
"def get_token_and_vocab_transform_and_special_token_ids(src_language, tgt_language, data_iterator):\n",
" vocab_transform = {}\n",
" token_transform = {src_language: get_tokenizer('spacy', language='en_core_web_sm'),\n",
" tgt_language: tokenize_code}\n",
"\n",
" # helper function to yield list of tokens\n",
" def yield_tokens(data_iter: Iterable, language: str) -> List[str]:\n",
" language_index = {src_language: 0, tgt_language: 1}\n",
"\n",
" for data_sample in data_iter:\n",
" yield token_transform[language](data_sample[language_index[language]])\n",
"\n",
" UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3\n",
" # Make sure the tokens are in order of their indices to properly insert them in vocab\n",
" special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']\n",
"\n",
" for ln in [src_language, tgt_language]:\n",
" # Create torchtext's Vocab object\n",
" vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(data_iterator, ln),\n",
" min_freq=1,\n",
" specials=special_symbols,\n",
" special_first=True)\n",
"\n",
" for ln in [src_language, tgt_language]:\n",
" vocab_transform[ln].set_default_index(UNK_IDX)\n",
"\n",
" return token_transform, vocab_transform, UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX\n",
"\n",
"\n",
"# helper function to club together sequential operations\n",
"def sequential_transforms(*transforms):\n",
" def func(txt_input):\n",
" for transform in transforms:\n",
" txt_input = transform(txt_input)\n",
" return txt_input\n",
"\n",
" return func\n",
"\n",
"\n",
"def ConvertToTokenIds(dataset, x_text_transform, y_text_transform):\n",
" new_dataset = []\n",
" for src_sample, tgt_sample in dataset:\n",
" new_dataset.append((x_text_transform(src_sample.rstrip(\"\\n\")), y_text_transform(tgt_sample.rstrip(\"\\n\"))))\n",
"\n",
" return new_dataset\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c44182e7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from timeit import default_timer as timer\n",
"from typing import List\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"\n",
"torch.manual_seed(42)\n",
"DEVICE = torch.device('cuda')\n",
"\n",
"SRC_LANGUAGE = 'description'\n",
"TGT_LANGUAGE = 'code'\n",
"\n",
"generate_csv_dataset = True\n",
"max_sequence_length = 64000\n",
"\n",
"train_count = 100000\n",
"validation_count = 1000\n",
"test_count = 500\n",
"\n",
"train_csv_filename = f'./code_search_net_train_{train_count}.csv'\n",
"validation_csv_filename = f'./code_search_net_validation_{validation_count}.csv'\n",
"test_csv_filename = f'./code_search_net_test_{test_count}.csv'\n",
"\n",
"if generate_csv_dataset:\n",
" generate_csv_dataset_from_huggingface(dataset=\"code_search_net\", dataset_filter=\"python\",\n",
" x_header='func_documentation_string',\n",
" y_header='func_code_string', train_count=train_count,\n",
" validation_count=validation_count, test_count=test_count,\n",
" train_csv_filename=train_csv_filename,\n",
" validation_csv_filename=validation_csv_filename,\n",
" test_csv_filename=test_csv_filename,\n",
" prepare_x_function=None,\n",
" prepare_y_function=remove_docstring_from_python_string)\n",
"\n",
"truncated_train_data, truncated_validation_data, truncated_test_data = load_csv_data_truncated(train_csv_filename,\n",
" validation_csv_filename,\n",
" test_csv_filename,\n",
" max_sequence_length)\n",
"\n",
"train_dataset_for_vocab_and_token_transform_generation = PyTorchDataset(truncated_train_data)\n",
"\n",
"token_transform, vocab_transform, UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = get_token_and_vocab_transform_and_special_token_ids(\n",
" SRC_LANGUAGE, TGT_LANGUAGE, train_dataset_for_vocab_and_token_transform_generation)\n",
"\n",
"# for src_sample, tgt_sample in train_dataset_for_vocab_and_token_transform_generation:\n",
"# print((token_transform[SRC_LANGUAGE](src_sample), token_transform[TGT_LANGUAGE](tgt_sample)))\n",
"\n",
"\n",
"# Add BOS/EOS and create tensor for input sequence indices\n",
"def tensor_transform(token_ids: List[int]):\n",
" return torch.cat((torch.tensor([BOS_IDX]),\n",
" torch.tensor(token_ids),\n",
" torch.tensor([EOS_IDX])))\n",
"\n",
"\n",
"# Text transforms to convert raw strings into tensors indices\n",
"text_transform = {}\n",
"for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n",
" text_transform[ln] = sequential_transforms(token_transform[ln], # Tokenization\n",
" vocab_transform[ln], # Numericalization\n",
" tensor_transform) # Add BOS/EOS and create tensor\n",
"\n",
"truncated_train_data = ConvertToTokenIds(truncated_train_data, text_transform[SRC_LANGUAGE],\n",
" text_transform[TGT_LANGUAGE])\n",
"truncated_validation_data = ConvertToTokenIds(truncated_validation_data, text_transform[SRC_LANGUAGE],\n",
" text_transform[TGT_LANGUAGE])\n",
"truncated_test_data = ConvertToTokenIds(truncated_test_data, text_transform[SRC_LANGUAGE], text_transform[TGT_LANGUAGE])\n",
"\n",
"\n",
"def generate_square_subsequent_mask(sz):\n",
" mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)\n",
" mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
" return mask\n",
"\n",
"\n",
"# Create attention masks.\n",
"def create_mask(src, tgt):\n",
" src_seq_len = src.shape[0]\n",
" tgt_seq_len = tgt.shape[0]\n",
"\n",
" tgt_mask = generate_square_subsequent_mask(tgt_seq_len)\n",
" src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)\n",
"\n",
" src_padding_mask = (src == PAD_IDX).transpose(0, 1)\n",
" tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)\n",
" return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask\n",
"\n",
"\n",
"# Collate data samples into batch tensors\n",
"def collate_fn(batch):\n",
" src_batch, tgt_batch = [], []\n",
" for src_sample, tgt_sample in batch:\n",
" src_batch.append(src_sample)\n",
" tgt_batch.append(tgt_sample)\n",
"\n",
" src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)\n",
" tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)\n",
" return src_batch, tgt_batch\n",
"\n",
"\n",
"def perform_training(model, optimizer, loss_fn, batch_size):\n",
" model.train()\n",
" losses = 0\n",
" train_iter = PyTorchDataset(truncated_train_data)\n",
" train_dataloader = DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)\n",
" with tqdm(total=int(len(train_iter) / batch_size)) as pbar:\n",
" for src, tgt in train_dataloader:\n",
" src = src.to(DEVICE)\n",
" src = src.long()\n",
" tgt = tgt.to(DEVICE)\n",
" tgt = tgt.long()\n",
" tgt_input = tgt[:-1, :]\n",
"\n",
" src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n",
"\n",
" logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" tgt_out = tgt[1:, :]\n",
" loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" losses += loss.item()\n",
" pbar.update(1)\n",
"\n",
" return losses / len(train_dataloader)\n",
"\n",
"\n",
"def perform_validation(model, loss_fn, batch_size):\n",
" model.eval()\n",
" losses = 0\n",
" val_iter = PyTorchDataset(truncated_validation_data)\n",
" val_dataloader = DataLoader(val_iter, collate_fn=collate_fn, batch_size=batch_size)\n",
" with tqdm(total=int(len(val_iter) / batch_size)) as pbar:\n",
" for src, tgt in val_dataloader:\n",
" src = src.to(DEVICE)\n",
" src = src.long()\n",
" tgt = tgt.to(DEVICE)\n",
" tgt = tgt.long()\n",
"\n",
" tgt_input = tgt[:-1, :]\n",
"\n",
" src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n",
"\n",
" logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)\n",
"\n",
" tgt_out = tgt[1:, :]\n",
" loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n",
" losses += loss.item()\n",
" pbar.update(1)\n",
" return losses / len(val_dataloader)\n",
"\n",
"\n",
"def lr_scheduler(step_num, d_model, warmup_steps=4000):\n",
" if step_num == 0:\n",
" return d_model ** -0.5\n",
" # Linearly increasing the learning rate for the first warmup_steps, and\n",
" # decreasing it thereafter\n",
" arg1 = step_num ** -0.5\n",
" arg2 = step_num * (warmup_steps ** -1.5)\n",
"\n",
" return (d_model ** -0.5) * min(arg1, arg2)\n",
"\n",
"\n",
"def accuracy_fcn(target, prediction):\n",
" # Find equal prediction and target values, and apply the padding mask\n",
" accuracy = (target == torch.argmax(prediction, dim=2)).float()\n",
" mask = (target != 0).float()\n",
"\n",
" return torch.sum(accuracy * mask) / torch.sum(mask)\n",
"\n",
"\n",
"def fit_transformer_model(transformer_model, num_epochs, batch_size):\n",
" loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)\n",
" optimizer = torch.optim.NAdam(transformer_model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 )\n",
" # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_scheduler(step, d_model))\n",
"\n",
" for epoch in range(1, num_epochs + 1):\n",
" start_time = timer()\n",
" train_loss = perform_training(transformer_model, optimizer, loss_fn, batch_size)\n",
" end_time = timer()\n",
" val_loss = perform_validation(transformer_model, loss_fn, batch_size)\n",
" print((\n",
" f\"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, \"f\"Epoch time = {(end_time - start_time):.3f}s\"))\n",
" if epoch % 5 == 0:\n",
" torch.save(transformer_model.state_dict(), f\"./checkpoint_at_epoch_{epoch}.pt\")\n",
" print(f\"Saved checkpoint at epoch {epoch}\")\n",
" # Update the learning rate\n",
" # scheduler.step()\n",
"\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"# function to generate output sequence using greedy algorithm\n",
"def greedy_decode(model, src, src_mask, max_len, start_symbol):\n",
" src = src.to(DEVICE)\n",
" src_mask = src_mask.to(DEVICE)\n",
"\n",
" memory = model.encode(src, src_mask)\n",
" ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)\n",
" for i in range(max_len - 1):\n",
" memory = memory.to(DEVICE)\n",
" tgt_mask = (generate_square_subsequent_mask(ys.size(0))\n",
" .type(torch.bool)).to(DEVICE)\n",
" out = model.decode(ys, memory, tgt_mask)\n",
" out = out.transpose(0, 1)\n",
" prob = model.generator(out[:, -1])\n",
" _, next_word = torch.max(prob, dim=1)\n",
" next_word = next_word.item()\n",
"\n",
" ys = torch.cat([ys,\n",
" torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)\n",
" if next_word == EOS_IDX:\n",
" break\n",
" return ys\n",
"\n",
"\n",
"# actual function to translate input sentence into target language\n",
"def translate(model: torch.nn.Module, src_sentence: str):\n",
" model.eval()\n",
" src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)\n",
" num_tokens = src.shape[0]\n",
" src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)\n",
" tgt_tokens = greedy_decode(\n",
" model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()\n",
" return \" \".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace(\"\",\n",
" \"\").replace(\n",
" \"\", \"\")\n",
"\n",
"\n",
"def test_transformer(transformer_model, sentence):\n",
" print(translate(transformer_model, sentence))\n",
"\n",
"\n",
"def train_transformer_model(fit_model, load_model, test_model):\n",
" src_vocab_size = len(vocab_transform[SRC_LANGUAGE])\n",
" tgt_vocab_size = len(vocab_transform[TGT_LANGUAGE])\n",
" d_model = 512\n",
" num_heads = 8\n",
" feed_forward_dim = 2048\n",
" num_encoder_layers = 6\n",
" num_decoder_layers = 6\n",
"\n",
" batch_size = 2\n",
" num_epochs = 120\n",
"\n",
" transformer = Seq2SeqTransformer(num_encoder_layers, num_decoder_layers, d_model,\n",
" num_heads, src_vocab_size, tgt_vocab_size, feed_forward_dim)\n",
"\n",
" if load_model:\n",
" transformer.load_state_dict(torch.load(\"./checkpoint_at_epoch_25.pt\"))\n",
" else:\n",
" for p in transformer.parameters():\n",
" if p.dim() > 1:\n",
" nn.init.xavier_uniform_(p)\n",
"\n",
" print(f'The model has {count_parameters(transformer):,} trainable parameters')\n",
" opt_transformer = torch.compile(transformer, backend='inductor')\n",
" opt_transformer = opt_transformer.to(DEVICE)\n",
" if fit_model:\n",
" fit_transformer_model(transformer, num_epochs, batch_size)\n",
"\n",
" if test_model:\n",
" test_transformer(transformer, 'Downloads Dailymotion videos by URL.')\n",
"\n",
"\n",
"train_transformer_model(True, False, True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}