Created
March 17, 2024 23:30
-
-
Save SauravMaheshkar/f01daa85639d23e817b6450e1d0eb45f to your computer and use it in GitHub Desktop.
p-tuning.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"private_outputs": true, | |
"provenance": [], | |
"cell_execution_strategy": "setup", | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyPnUBWluJ5e8QUDZIw6TZ+z", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/SauravMaheshkar/f01daa85639d23e817b6450e1d0eb45f/p-tuning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"📦 Packages and Basic Setup\n", | |
"---" | |
], | |
"metadata": { | |
"id": "z2w5BUSjkWuu" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "kytLsvmzhXQ7" | |
}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"!pip install -q peft transformers datasets evaluate wandb ml-collections" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"from google.colab import userdata\n", | |
"\n", | |
"key = userdata.get('W&B')\n", | |
"os.environ[\"WANDB_API_KEY\"] = key" | |
], | |
"metadata": { | |
"id": "3R7I2dzvklKb" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# @title ⚙️ Configuration\n", | |
"\n", | |
"import ml_collections\n", | |
"\n", | |
"def get_config() -> ml_collections.ConfigDict:\n", | |
" config = ml_collections.ConfigDict()\n", | |
" config.model: str = \"roberta-base\" # @param {type: \"string\"}\n", | |
" config.task: str = \"mrpc\" # @param {type: \"string\"}\n", | |
" config.batch_size: int = 128 # @param {type: \"number\"}\n", | |
" config.num_epochs: int = 10 # @param {type: \"number\"}\n", | |
" config.learning_rate: int = 1e-3 # @param {type: \"number\"}\n", | |
" config.dataset: str = \"glue\" # @param {type: \"string\"}\n", | |
" config.wandb_entity: str = \"sauravmaheshkar\" # @param {type: \"string\"}\n", | |
"\n", | |
" return config\n", | |
"\n", | |
"config = get_config()" | |
], | |
"metadata": { | |
"id": "VpTZcxcLklbH" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import wandb\n", | |
"\n", | |
"wandb.init(project=\"softprompts\", entity=config.wandb_entity, job_type=\"train\", group=\"p-tuning\", config = config.to_dict())\n", | |
"\n", | |
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", | |
"os.environ[\"WANDB_WATCH\"]=\"false\"\n", | |
"os.environ[\"WANDB_LOG_MODEL\"]=\"true\"" | |
], | |
"metadata": { | |
"id": "SuHeVMBvlGcW" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 💿 The Dataset\n", | |
"---" | |
], | |
"metadata": { | |
"id": "SfYg0fyHlSR0" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from datasets import load_dataset\n", | |
"\n", | |
"dataset = load_dataset(config.dataset, config.task)" | |
], | |
"metadata": { | |
"id": "p3HEfwKxlTiK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from transformers import AutoTokenizer, DataCollatorWithPadding\n", | |
"\n", | |
"if any(k in config.model for k in (\"gpt\", \"opt\", \"bloom\")):\n", | |
" padding_side = \"left\"\n", | |
"else:\n", | |
" padding_side = \"right\"\n", | |
"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(config.model, padding_side=padding_side)\n", | |
"if getattr(tokenizer, \"pad_token_id\") is None:\n", | |
" tokenizer.pad_token_id = tokenizer.eos_token_id\n", | |
"\n", | |
"def tokenize_function(examples):\n", | |
" # max_length=None => use the model max length (it's actually the default)\n", | |
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n", | |
" return outputs\n", | |
"\n", | |
"tokenized_datasets = dataset.map(\n", | |
" tokenize_function,\n", | |
" batched=True,\n", | |
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n", | |
")\n", | |
"\n", | |
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", | |
"\n", | |
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=\"longest\")" | |
], | |
"metadata": { | |
"id": "Sw6gudKLlY99" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## ✍️ Model Architecture & Training\n", | |
"---" | |
], | |
"metadata": { | |
"id": "BnjR9fvdltI4" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from peft import PromptEncoderConfig\n", | |
"\n", | |
"peft_config = PromptEncoderConfig(\n", | |
" task_type=\"SEQ_CLS\",\n", | |
" num_virtual_tokens=20,\n", | |
" encoder_hidden_size=128\n", | |
")" | |
], | |
"metadata": { | |
"id": "3YWoOGKtluNT" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from peft import get_peft_model\n", | |
"from transformers import AutoModelForSequenceClassification\n", | |
"\n", | |
"model = AutoModelForSequenceClassification.from_pretrained(config.model, return_dict=True)\n", | |
"model = get_peft_model(model, peft_config)" | |
], | |
"metadata": { | |
"id": "QEPm1nDtl4Jg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import evaluate\n", | |
"import numpy as np\n", | |
"\n", | |
"metric = evaluate.load(config.dataset, config.task)\n", | |
"\n", | |
"def compute_metrics(eval_pred):\n", | |
" predictions, labels = eval_pred\n", | |
" predictions = np.argmax(predictions, axis=1)\n", | |
" return metric.compute(predictions=predictions, references=labels)" | |
], | |
"metadata": { | |
"id": "iv2lF5dJmbyW" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from transformers import TrainingArguments\n", | |
"\n", | |
"training_args = TrainingArguments(\n", | |
" output_dir=f\"{config.model}-peft-p-tuning\",\n", | |
" learning_rate=config.learning_rate,\n", | |
" per_device_train_batch_size=config.batch_size,\n", | |
" per_device_eval_batch_size=config.batch_size,\n", | |
" num_train_epochs=config.num_epochs,\n", | |
" weight_decay=0.01,\n", | |
" evaluation_strategy=\"epoch\",\n", | |
" save_strategy=\"epoch\",\n", | |
" load_best_model_at_end=True,\n", | |
" report_to=[\"wandb\"],\n", | |
")" | |
], | |
"metadata": { | |
"id": "N77mAHAZmFg2" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from transformers import Trainer\n", | |
"\n", | |
"trainer = Trainer(\n", | |
" model=model,\n", | |
" args=training_args,\n", | |
" train_dataset=tokenized_datasets[\"train\"],\n", | |
" eval_dataset=tokenized_datasets[\"test\"],\n", | |
" tokenizer=tokenizer,\n", | |
" data_collator=data_collator,\n", | |
" compute_metrics=compute_metrics,\n", | |
")\n", | |
"\n", | |
"train_results = trainer.train()" | |
], | |
"metadata": { | |
"id": "vVXXO0UnmZ_5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"wandb.config.train_results = train_results\n", | |
"wandb.finish()" | |
], | |
"metadata": { | |
"id": "kuR6MhOunIYR" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment