@SauravMaheshkar
Created March 17, 2024 23:30
p-tuning.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"cell_execution_strategy": "setup",
"gpuType": "T4",
"authorship_tag": "ABX9TyPnUBWluJ5e8QUDZIw6TZ+z",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/SauravMaheshkar/f01daa85639d23e817b6450e1d0eb45f/p-tuning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"📦 Packages and Basic Setup\n",
"---"
],
"metadata": {
"id": "z2w5BUSjkWuu"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kytLsvmzhXQ7"
},
"outputs": [],
"source": [
"%%capture\n",
"!pip install -q peft transformers datasets evaluate wandb ml-collections"
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"from google.colab import userdata\n",
"\n",
"key = userdata.get('W&B')\n",
"os.environ[\"WANDB_API_KEY\"] = key"
],
"metadata": {
"id": "3R7I2dzvklKb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title ⚙️ Configuration\n",
"\n",
"import ml_collections\n",
"\n",
"def get_config() -> ml_collections.ConfigDict:\n",
" config = ml_collections.ConfigDict()\n",
" config.model: str = \"roberta-base\" # @param {type: \"string\"}\n",
" config.task: str = \"mrpc\" # @param {type: \"string\"}\n",
" config.batch_size: int = 128 # @param {type: \"number\"}\n",
" config.num_epochs: int = 10 # @param {type: \"number\"}\n",
" config.learning_rate: int = 1e-3 # @param {type: \"number\"}\n",
" config.dataset: str = \"glue\" # @param {type: \"string\"}\n",
" config.wandb_entity: str = \"sauravmaheshkar\" # @param {type: \"string\"}\n",
"\n",
" return config\n",
"\n",
"config = get_config()"
],
"metadata": {
"id": "VpTZcxcLklbH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import wandb\n",
"\n",
"wandb.init(project=\"softprompts\", entity=config.wandb_entity, job_type=\"train\", group=\"p-tuning\", config = config.to_dict())\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"os.environ[\"WANDB_WATCH\"]=\"false\"\n",
"os.environ[\"WANDB_LOG_MODEL\"]=\"true\""
],
"metadata": {
"id": "SuHeVMBvlGcW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 💿 The Dataset\n",
"---"
],
"metadata": {
"id": "SfYg0fyHlSR0"
}
},
{
"cell_type": "code",
"source": [
"from datasets import load_dataset\n",
"\n",
"dataset = load_dataset(config.dataset, config.task)"
],
"metadata": {
"id": "p3HEfwKxlTiK"
},
"execution_count": null,
"outputs": []
},
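{
"cell_type": "markdown",
"source": [
"A quick sanity check using only the standard `datasets` API: list the splits and peek at one raw MRPC example before tokenizing."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Inspect the DatasetDict: split names, sizes, and one raw example.\n",
"print(dataset)\n",
"print(dataset[\"train\"][0])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},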
{
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
"\n",
"if any(k in config.model for k in (\"gpt\", \"opt\", \"bloom\")):\n",
" padding_side = \"left\"\n",
"else:\n",
" padding_side = \"right\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(config.model, padding_side=padding_side)\n",
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"\n",
"def tokenize_function(examples):\n",
" # max_length=None => use the model max length (it's actually the default)\n",
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
" return outputs\n",
"\n",
"tokenized_datasets = dataset.map(\n",
" tokenize_function,\n",
" batched=True,\n",
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
")\n",
"\n",
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
"\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=\"longest\")"
],
"metadata": {
"id": "Sw6gudKLlY99"
},
"execution_count": null,
"outputs": []
},
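{
"cell_type": "markdown",
"source": [
"A minimal check of the collator, assuming the default tensor return type of `DataCollatorWithPadding` (PyTorch): collate a few examples by hand and inspect the padded shapes."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Collate a handful of tokenized examples and print the padded tensor shapes.\n",
"sample_batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(4)])\n",
"print({k: v.shape for k, v in sample_batch.items()})"
],
"metadata": {},
"execution_count": null,
"outputs": []
},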
{
"cell_type": "markdown",
"source": [
"## ✍️ Model Architecture & Training\n",
"---"
],
"metadata": {
"id": "BnjR9fvdltI4"
}
},
{
"cell_type": "code",
"source": [
"from peft import PromptEncoderConfig\n",
"\n",
"peft_config = PromptEncoderConfig(\n",
" task_type=\"SEQ_CLS\",\n",
" num_virtual_tokens=20,\n",
" encoder_hidden_size=128\n",
")"
],
"metadata": {
"id": "3YWoOGKtluNT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from peft import get_peft_model\n",
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(config.model, return_dict=True)\n",
"model = get_peft_model(model, peft_config)"
],
"metadata": {
"id": "QEPm1nDtl4Jg"
},
"execution_count": null,
"outputs": []
},
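{
"cell_type": "markdown",
"source": [
"P-tuning updates only the prompt encoder plus the task head, so the trainable-parameter count should be a tiny fraction of the base model; `print_trainable_parameters()` (a standard `PeftModel` helper) makes this explicit."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Report trainable vs. total parameters for the p-tuned model.\n",
"model.print_trainable_parameters()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},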
{
"cell_type": "code",
"source": [
"import evaluate\n",
"import numpy as np\n",
"\n",
"metric = evaluate.load(config.dataset, config.task)\n",
"\n",
"def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" predictions = np.argmax(predictions, axis=1)\n",
" return metric.compute(predictions=predictions, references=labels)"
],
"metadata": {
"id": "iv2lF5dJmbyW"
},
"execution_count": null,
"outputs": []
},
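{
"cell_type": "markdown",
"source": [
"A smoke test of `compute_metrics` on made-up two-class logits (the values below are purely illustrative), just to verify the argmax-and-metric plumbing before training:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Dummy 2-class logits and labels (illustrative values only).\n",
"dummy_logits = np.array([[0.1, 0.9], [0.8, 0.2]])\n",
"dummy_labels = np.array([1, 0])\n",
"print(compute_metrics((dummy_logits, dummy_labels)))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},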
{
"cell_type": "code",
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=f\"{config.model}-peft-p-tuning\",\n",
" learning_rate=config.learning_rate,\n",
" per_device_train_batch_size=config.batch_size,\n",
" per_device_eval_batch_size=config.batch_size,\n",
" num_train_epochs=config.num_epochs,\n",
" weight_decay=0.01,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
" report_to=[\"wandb\"],\n",
")"
],
"metadata": {
"id": "N77mAHAZmFg2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from transformers import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_datasets[\"train\"],\n",
" eval_dataset=tokenized_datasets[\"test\"],\n",
" tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"train_results = trainer.train()"
],
"metadata": {
"id": "vVXXO0UnmZ_5"
},
"execution_count": null,
"outputs": []
},
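{
"cell_type": "markdown",
"source": [
"Because `load_best_model_at_end=True`, a final `trainer.evaluate()` pass reports the best checkpoint's metrics on the eval split:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Evaluate the best checkpoint on the eval split and print the metrics.\n",
"eval_metrics = trainer.evaluate()\n",
"print(eval_metrics)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},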
{
"cell_type": "code",
"source": [
"wandb.config.train_results = train_results\n",
"wandb.finish()"
],
"metadata": {
"id": "kuR6MhOunIYR"
},
"execution_count": null,
"outputs": []
}
]
}