Skip to content

Instantly share code, notes, and snippets.

@radekosmulski
Created August 24, 2023 23:34
Show Gist options
  • Save radekosmulski/c3cce1a52b52b9b2037e1941de5afa32 to your computer and use it in GitHub Desktop.
Save radekosmulski/c3cce1a52b52b9b2037e1941de5afa32 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a5ccac1c",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"if not os.environ.get('TRANSFORMERS_CACHE'):\n",
" os.environ['TRANSFORMERS_CACHE'] = '/raid/transformers_cache'\n",
" \n",
"\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = \"6\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2221804b",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"from dataclasses import dataclass, field\n",
"from typing import Dict, Optional, Sequence\n",
"import warnings\n",
"\n",
"from tqdm import tqdm\n",
"from pdb import set_trace\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import transformers\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "a51d944d",
"metadata": {},
"source": [
"Let's load the Alpaca dataset straight from the Hugging Face `datasets` library."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4dd95b68",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset, DatasetDict\n",
"dataset = load_dataset(\"tatsu-lab/alpaca\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "549878ac",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb348acee923408aa857194ed13af371",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')\n",
"tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8d6ddb3d",
"metadata": {},
"outputs": [],
"source": [
"datasets = dataset['train'].train_test_split(test_size=2002, seed=42)\n",
"datasets = DatasetDict({'train': datasets['train'], 'valid': datasets['test']})"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "43b496fa",
"metadata": {},
"outputs": [],
"source": [
"# code from Stanford Alpaca https://github.com/tatsu-lab/stanford_alpaca\n",
"\n",
"PROMPT_DICT = {\n",
" \"prompt_input\": (\n",
" \"Below is an instruction that describes a task, paired with an input that provides further context. \"\n",
" \"Write a response that appropriately completes the request.\\n\\n\"\n",
" \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\"\n",
" ),\n",
" \"prompt_no_input\": (\n",
" \"Below is an instruction that describes a task. \"\n",
" \"Write a response that appropriately completes the request.\\n\\n\"\n",
" \"### Instruction:\\n{instruction}\\n\\n### Response:\"\n",
" ),\n",
"}\n",
"\n",
"def smart_tokenizer_and_embedding_resize(\n",
" special_tokens_dict: Dict,\n",
" tokenizer: transformers.PreTrainedTokenizer,\n",
" model: transformers.PreTrainedModel,\n",
"):\n",
" \"\"\"Resize tokenizer and embedding.\n",
"\n",
" Note: This is the unoptimized version that may make your embedding size not be divisible by 64.\n",
" \"\"\"\n",
" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n",
" model.resize_token_embeddings(len(tokenizer))\n",
"\n",
" if num_new_tokens > 0:\n",
" input_embeddings = model.get_input_embeddings().weight.data\n",
" output_embeddings = model.get_output_embeddings().weight.data\n",
"\n",
" input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)\n",
" output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)\n",
"\n",
" input_embeddings[-num_new_tokens:] = input_embeddings_avg\n",
" output_embeddings[-num_new_tokens:] = output_embeddings_avg\n",
" \n",
"special_tokens_dict = dict()\n",
"special_tokens_dict[\"pad_token\"] = \"[PAD]\"\n",
"\n",
"smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ca648e3d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, 32000]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.encode('[PAD]')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5f3b7326",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[PAD]'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode([32000])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "edfd9e7b",
"metadata": {},
"outputs": [],
"source": [
"def process_example(example):\n",
" template = PROMPT_DICT[\"prompt_input\"]\n",
" if not example['input']:\n",
" template = PROMPT_DICT[\"prompt_no_input\"] + '\\n\\n'\n",
"\n",
" prompt = template.format_map(example)\n",
" prompt_toks = tokenizer(prompt)['input_ids']\n",
" input_ids = tokenizer(prompt + example[\"output\"] + tokenizer.eos_token, return_tensors='pt')['input_ids'][0]\n",
" labels = input_ids.clone().detach()\n",
" labels[:len(prompt_toks)] = -100 # loss will not be calculated for labels set to -100\n",
" return input_ids, labels"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "36c9cd51",
"metadata": {},
"outputs": [],
"source": [
"class SupervisedDataset(Dataset):\n",
" def __init__(self, dataset):\n",
" super().__init__()\n",
" discarded_examples_count = 0\n",
" self.examples = []\n",
" for example in tqdm(dataset):\n",
" input_ids, labels = process_example(example)\n",
" if input_ids.shape[0] > 512:\n",
" discarded_examples_count += 1\n",
" else:\n",
" self.examples.append((input_ids, labels))\n",
" print(f'Discarded {discarded_examples_count} examples due to length > 512.')\n",
" \n",
" def __getitem__(self, idx):\n",
" return {\"input_ids\": self.examples[idx][0], \"labels\": self.examples[idx][1]}\n",
" def __len__(self):\n",
" return len(self.examples)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4a4bf62b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:38<00:00, 1295.40it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarded 94 examples due to length > 512.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2002/2002 [00:01<00:00, 1244.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarded 3 examples due to length > 512.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"train_ds = SupervisedDataset(datasets['train'])\n",
"valid_ds = SupervisedDataset(datasets['valid'])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5e3628de",
"metadata": {},
"outputs": [],
"source": [
"def collate_fn(examples):\n",
" input_ids, labels = tuple([example[key] for example in examples] for key in (\"input_ids\", \"labels\"))\n",
" input_ids = torch.nn.utils.rnn.pad_sequence(\n",
" input_ids, batch_first=True, padding_value=tokenizer.pad_token_id\n",
" )\n",
" labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)\n",
" return dict(\n",
" input_ids=input_ids,\n",
" labels=labels,\n",
" attention_mask=input_ids.ne(tokenizer.pad_token_id)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a802eed6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 4,194,304 || all params: 6,742,618,112 || trainable%: 0.06220586618327525\n"
]
}
],
"source": [
"from peft import LoraConfig, TaskType\n",
"from peft import get_peft_model\n",
"\n",
"peft_config = LoraConfig(task_type=\"a_random_string\", inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e59e53e7",
"metadata": {},
"outputs": [],
"source": [
"train_batch_size = 6\n",
"lr = 4e-4\n",
"num_epochs = 3"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c07a2e76",
"metadata": {},
"outputs": [],
"source": [
"from accelerate import Accelerator\n",
"\n",
"accelerator = Accelerator(mixed_precision='bf16', gradient_accumulation_steps=128//train_batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "bc81f95f",
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)\n",
"valid_dl = DataLoader(valid_ds, batch_size=2*train_batch_size, shuffle=False, collate_fn=collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "aa58ff68",
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)\n",
"lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
" optimizer,\n",
" lr,\n",
" epochs=num_epochs,\n",
" steps_per_epoch=len(train_dl)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "eb6dd55a",
"metadata": {},
"outputs": [],
"source": [
"model, train_dl, valid_dl, optimizer, lr_scheduler = accelerator.prepare(\n",
" model, train_dl, valid_dl, optimizer, lr_scheduler\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4e622a19",
"metadata": {},
"outputs": [],
"source": [
"lrs = []\n",
"train_losses = []"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "6a88692a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch: 1\tTrain loss: 1.15: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8318/8318 [49:56<00:00, 2.78it/s]\n",
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss: 1.08\ttoken accuracy: 0.00\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch: 1\tTrain loss: 1.04: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8318/8318 [49:46<00:00, 2.78it/s]\n",
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss: 1.07\ttoken accuracy: 0.00\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch: 1\tTrain loss: 0.90: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8318/8318 [49:56<00:00, 2.78it/s]\n",
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss: 1.07\ttoken accuracy: 0.00\n",
"CPU times: user 2h 23s, sys: 32min 31s, total: 2h 32min 55s\n",
"Wall time: 2h 32min 36s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r"
]
}
],
"source": [
"%%time\n",
"\n",
"for i in range(num_epochs):\n",
" model.train()\n",
" pbar = tqdm(train_dl, leave=True)\n",
" for batch in pbar:\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
"\n",
" train_losses.append(loss.item())\n",
" lrs.append(optimizer.param_groups[0]['lr'])\n",
"\n",
" accelerator.backward(loss)\n",
" \n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
" lr_scheduler.step()\n",
" pbar.set_description(f'Epoch: {i + 1:2d}\\tTrain loss: {np.mean(train_losses[-20:]) :.2f}')\n",
"\n",
" model.eval()\n",
" preds = []\n",
" labels = []\n",
" val_losses = []\n",
" for batch in tqdm(valid_dl, leave=False):\n",
" with torch.no_grad():\n",
" outputs = model(**batch)\n",
"\n",
" logits = outputs.logits\n",
" val_losses.append(outputs.loss.item())\n",
"\n",
" preds.append(outputs.logits.argmax(-1).cpu().detach())\n",
" labels.append(batch['labels'].cpu().detach())\n",
"\n",
" hits = 0\n",
" chances = 0\n",
" for p, l in zip(preds, labels):\n",
" hits += (p[:, :-1] == l[:, 1:]).sum().item()  # logits at position t predict token t+1, so shift before comparing\n",
" chances += (l != -100).sum().item()\n",
" print(f'Val loss: {np.mean(val_losses):3.02f}\\ttoken accuracy: {hits/chances:3.02f}')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "673395f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8d98185a20>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(lrs)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "da5e3cab",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8d98c16bf0>]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(train_losses)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "a88dc232",
"metadata": {},
"outputs": [],
"source": [
"accelerator.free_memory()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "87ddc33d",
"metadata": {},
"outputs": [],
"source": [
"model.save_pretrained('/raid/models/lora_apaca_llama2')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ec39dd3",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os._exit(00)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment