@sayakpaul
Created January 26, 2023 05:38
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e8447b79-1314-4c8c-b3ab-2888a0dc059c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/py38/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import transformers \n",
"import accelerate\n",
"import peft"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "445160ea-c60f-4bad-bfef-65065616f249",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transformers version: 4.25.1\n",
"Accelerate version: 0.15.0\n",
"PEFT version: 0.1.0.dev0\n"
]
}
],
"source": [
"print(f\"Transformers version: {transformers.__version__}\")\n",
"print(f\"Accelerate version: {accelerate.__version__}\")\n",
"print(f\"PEFT version: {peft.__version__}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "197ff6c1-cc05-4ba8-ab6a-0dd7261df58f",
"metadata": {},
"outputs": [],
"source": [
"model_checkpoint = \"google/vit-base-patch16-224-in21k\" # pre-trained model from which to fine-tune\n",
"batch_size = 128"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e8917199-67d4-4f82-9e99-dd19bfa1b4e6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset food101 (/home/jupyter/.cache/huggingface/datasets/food101/default/0.0.0/7cebe41a80fb2da3f08fcbef769c8874073a86346f7fb96dc0847d4dfc318295)\n"
]
}
],
"source": [
"from datasets import load_dataset \n",
"\n",
"dataset = load_dataset(\"food101\", split=\"train[:5000]\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ec2bf41b-1c0f-4915-85e8-7b733c7949b8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['image', 'label'],\n",
" num_rows: 5000\n",
"})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "64e0f63d-7fc1-4688-9762-d428332e5698",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'baklava'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = dataset.features[\"label\"].names\n",
"label2id, id2label = dict(), dict()\n",
"for i, label in enumerate(labels):\n",
" label2id[label] = i\n",
" id2label[i] = label\n",
"\n",
"id2label[2]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e85b70da-902f-4506-8d1b-0fb8b85fea68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ViTImageProcessor {\n",
" \"do_normalize\": true,\n",
" \"do_rescale\": true,\n",
" \"do_resize\": true,\n",
" \"image_mean\": [\n",
" 0.5,\n",
" 0.5,\n",
" 0.5\n",
" ],\n",
" \"image_processor_type\": \"ViTImageProcessor\",\n",
" \"image_std\": [\n",
" 0.5,\n",
" 0.5,\n",
" 0.5\n",
" ],\n",
" \"resample\": 2,\n",
" \"rescale_factor\": 0.00392156862745098,\n",
" \"size\": {\n",
" \"height\": 224,\n",
" \"width\": 224\n",
" }\n",
"}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoImageProcessor\n",
"\n",
"image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)\n",
"image_processor"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eeec77a5-9f8c-450e-8f2f-e8e9c0027f8b",
"metadata": {},
"outputs": [],
"source": [
"from torchvision.transforms import (\n",
" CenterCrop,\n",
" Compose,\n",
" Normalize,\n",
" RandomHorizontalFlip,\n",
" RandomResizedCrop,\n",
" Resize,\n",
" ToTensor,\n",
")\n",
"\n",
"normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n",
"train_transforms = Compose(\n",
" [\n",
" RandomResizedCrop(image_processor.size[\"height\"]),\n",
" RandomHorizontalFlip(),\n",
" ToTensor(),\n",
" normalize,\n",
" ]\n",
" )\n",
"\n",
"val_transforms = Compose(\n",
" [\n",
" Resize(image_processor.size[\"height\"]),\n",
" CenterCrop(image_processor.size[\"height\"]),\n",
" ToTensor(),\n",
" normalize,\n",
" ]\n",
" )\n",
"\n",
"def preprocess_train(example_batch):\n",
" \"\"\"Apply train_transforms across a batch.\"\"\"\n",
" example_batch[\"pixel_values\"] = [\n",
" train_transforms(image.convert(\"RGB\")) for image in example_batch[\"image\"]\n",
" ]\n",
" return example_batch\n",
"\n",
"def preprocess_val(example_batch):\n",
" \"\"\"Apply val_transforms across a batch.\"\"\"\n",
" example_batch[\"pixel_values\"] = [val_transforms(image.convert(\"RGB\")) for image in example_batch[\"image\"]]\n",
" return example_batch\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fd5dee84-437e-45d2-9ccd-245000da9ae8",
"metadata": {},
"outputs": [],
"source": [
"# split up training into training + validation\n",
"splits = dataset.train_test_split(test_size=0.1)\n",
"train_ds = splits[\"train\"]\n",
"val_ds = splits[\"test\"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f1528e46-bae5-4701-b1c1-e88be179fa06",
"metadata": {},
"outputs": [],
"source": [
"train_ds.set_transform(preprocess_train)\n",
"val_ds.set_transform(preprocess_val)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f16600d2-1fc4-4a5a-9342-bebf0ccc335a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']\n",
"- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n",
"\n",
"model = AutoModelForImageClassification.from_pretrained(\n",
" model_checkpoint, \n",
" label2id=label2id,\n",
" id2label=id2label,\n",
" ignore_mismatched_sizes=True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "397a4d74-7fc4-4d42-9637-485dac28e255",
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig, LoraModel\n",
"\n",
"\n",
"config = LoraConfig(\n",
" r=16,\n",
" lora_alpha=16,\n",
" target_modules=[\"query\", \"value\"],\n",
" lora_dropout=0.0,\n",
" bias=\"none\",\n",
")\n",
"lora_model = LoraModel(config, model)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "80ae9b51-f17f-42cb-9499-75b3f025d13a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"589824 589824\n"
]
}
],
"source": [
"def count_trainable_params(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"print(count_trainable_params(model), count_trainable_params(lora_model))"
]
},
{
"cell_type": "markdown",
"id": "188392c0-be72-4d6c-9d79-830203a44168",
"metadata": {},
"source": [
"__Shouldn't the number of trainable params for the `lora_model` be lesser than that of `model`?__"
]
},
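{
"cell_type": "markdown",
"id": "added-note-trainable-params",
"metadata": {},
"source": [
"_Added note (an assumption, not verified here): `LoraModel(config, model)` may inject the LoRA layers into the passed `model` in place, in which case both counts would be taken over the same, already-adapted module and would therefore match. The next cell is a hypothetical diagnostic sketch that reloads an untouched copy of the base checkpoint for comparison; `fresh_model` is an illustrative name and not part of the original notebook._"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-sketch-trainable-params",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical diagnostic sketch (assumption: LoraModel modified `model` in place).\n",
"# Reload the base checkpoint to get a reference parameter count untouched by LoRA.\n",
"fresh_model = AutoModelForImageClassification.from_pretrained(\n",
"    model_checkpoint,\n",
"    label2id=label2id,\n",
"    id2label=id2label,\n",
"    ignore_mismatched_sizes=True,\n",
")\n",
"\n",
"print(\"Fresh base model:\", count_trainable_params(fresh_model))\n",
"print(\"Model passed to LoraModel:\", count_trainable_params(model))\n",
"print(\"LoraModel wrapper:\", count_trainable_params(lora_model))"
]
},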
{
"cell_type": "code",
"execution_count": 14,
"id": "e558954d-c6e0-4715-bce5-13aaf4d1d5ac",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments, Trainer\n",
"\n",
"\n",
"model_name = model_checkpoint.split(\"/\")[-1]\n",
"\n",
"args = TrainingArguments(\n",
" f\"{model_name}-finetuned-lora-food101\",\n",
" remove_unused_columns=False,\n",
" evaluation_strategy = \"epoch\",\n",
" save_strategy = \"epoch\",\n",
" learning_rate=1e-3,\n",
" per_device_train_batch_size=batch_size,\n",
" gradient_accumulation_steps=4,\n",
" per_device_eval_batch_size=batch_size,\n",
" num_train_epochs=3,\n",
" warmup_ratio=0.1,\n",
" logging_steps=10,\n",
" load_best_model_at_end=True,\n",
" metric_for_best_model=\"accuracy\",\n",
" push_to_hub=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "78775cc0-ef54-4d0d-b786-17774d318f02",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import evaluate\n",
"\n",
"\n",
"metric = evaluate.load(\"accuracy\")\n",
"\n",
"# the compute_metrics function takes a Named Tuple as input:\n",
"# predictions, which are the logits of the model as Numpy arrays,\n",
"# and label_ids, which are the ground-truth labels as Numpy arrays.\n",
"def compute_metrics(eval_pred):\n",
" \"\"\"Computes accuracy on a batch of predictions\"\"\"\n",
" predictions = np.argmax(eval_pred.predictions, axis=1)\n",
" return metric.compute(predictions=predictions, references=eval_pred.label_ids)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6412dfd1-9625-4c22-94ef-9e6d851f80f6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def collate_fn(examples):\n",
" pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
" labels = torch.tensor([example[\"label\"] for example in examples])\n",
" return {\"pixel_values\": pixel_values, \"labels\": labels}\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e976a734-393c-4db0-8fcf-8cd68c99fcfa",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jupyter/vit-base-patch16-224-in21k-finetuned-lora-food101 is already a clone of https://huggingface.co/sayakpaul/vit-base-patch16-224-in21k-finetuned-lora-food101. Make sure you pull the latest changes with `repo.git_pull()`.\n",
"/opt/conda/envs/py38/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n",
"***** Running training *****\n",
" Num examples = 4500\n",
" Num Epochs = 3\n",
" Instantaneous batch size per device = 128\n",
" Total train batch size (w. parallel, distributed & accumulation) = 512\n",
" Gradient Accumulation steps = 4\n",
" Total optimization steps = 27\n",
" Number of trainable parameters = 589824\n",
"Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msayakpaul\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.13.9"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/home/jupyter/wandb/run-20230126_053041-gry5sv36</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href=\"https://wandb.ai/sayakpaul/huggingface/runs/gry5sv36\" target=\"_blank\">vit-base-patch16-224-in21k-finetuned-lora-food101</a></strong> to <a href=\"https://wandb.ai/sayakpaul/huggingface\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://wandb.me/run\" target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href=\"https://wandb.ai/sayakpaul/huggingface\" target=\"_blank\">https://wandb.ai/sayakpaul/huggingface</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href=\"https://wandb.ai/sayakpaul/huggingface/runs/gry5sv36\" target=\"_blank\">https://wandb.ai/sayakpaul/huggingface/runs/gry5sv36</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='27' max='27' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [27/27 06:21, Epoch 3/3]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Epoch</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>No log</td>\n",
" <td>4.580509</td>\n",
" <td>0.020000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>4.630000</td>\n",
" <td>4.255508</td>\n",
" <td>0.266000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>4.393500</td>\n",
" <td>4.014057</td>\n",
" <td>0.208000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running Evaluation *****\n",
" Num examples = 500\n",
" Batch size = 128\n",
"Saving model checkpoint to vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-9\n",
"Configuration saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-9/config.json\n",
"Model weights saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-9/pytorch_model.bin\n",
"Image processor saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-9/preprocessor_config.json\n",
"Image processor saved in vit-base-patch16-224-in21k-finetuned-lora-food101/preprocessor_config.json\n",
"***** Running Evaluation *****\n",
" Num examples = 500\n",
" Batch size = 128\n",
"Saving model checkpoint to vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-18\n",
"Configuration saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-18/config.json\n",
"Model weights saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-18/pytorch_model.bin\n",
"Image processor saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-18/preprocessor_config.json\n",
"***** Running Evaluation *****\n",
" Num examples = 500\n",
" Batch size = 128\n",
"Saving model checkpoint to vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-27\n",
"Configuration saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-27/config.json\n",
"Model weights saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-27/pytorch_model.bin\n",
"Image processor saved in vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-27/preprocessor_config.json\n",
"\n",
"\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\n",
"\n",
"Loading best model from vit-base-patch16-224-in21k-finetuned-lora-food101/checkpoint-18 (score: 0.266).\n"
]
}
],
"source": [
"trainer = Trainer(\n",
" model,\n",
" args,\n",
" train_dataset=train_ds,\n",
" eval_dataset=val_ds,\n",
" tokenizer=image_processor,\n",
" compute_metrics=compute_metrics,\n",
" data_collator=collate_fn,\n",
")\n",
"train_results = trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "a7e43c37-06cb-44a2-a9cd-4836a582d8d5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running Evaluation *****\n",
" Num examples = 500\n",
" Batch size = 128\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='4' max='4' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [4/4 00:05]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'eval_loss': 4.2555084228515625,\n",
" 'eval_accuracy': 0.266,\n",
" 'eval_runtime': 7.5852,\n",
" 'eval_samples_per_second': 65.918,\n",
" 'eval_steps_per_second': 0.527,\n",
" 'epoch': 3.0}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluate(val_ds)"
]
}
],
"metadata": {
"environment": {
"kernel": "py38",
"name": "common-cu110.m102",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m102"
},
"kernelspec": {
"display_name": "py38",
"language": "python",
"name": "py38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}