Skip to content

Instantly share code, notes, and snippets.

@metric-space
Created May 19, 2023 22:22
Show Gist options
  • Save metric-space/a315a97d581039cd47888c24caba9f65 to your computer and use it in GitHub Desktop.
Save metric-space/a315a97d581039cd47888c24caba9f65 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"\n",
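"**Best-of-n sampling class usage**\n",
"\n",
"Generates continuations for a batch of IMDb review prefixes with trl's `BestOfNSampler`, using `lvwerra/gpt2-imdb` as the reference model and the `lvwerra/distilbert-imdb` sentiment pipeline as the reward.\n",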
"\n"
],
"metadata": {
"id": "WQpNapZNWuXP"
}
},
{
"cell_type": "markdown",
"source": [
"Install and import dependencies\n"
],
"metadata": {
"id": "Lo98lkdP66_x"
}
},
{
"cell_type": "code",
"source": [
"# trl comes from the fork/branch that provides trl.extras.BestOfNSampler (imported in the next cell).\n",
"%pip install torch datasets transformers git+https://github.com/metric-space/trl.git@140/best-of-n-sampling-class"
],
"metadata": {
"id": "vDA6qayz692w"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import pandas as pd\n",
"from transformers import pipeline, AutoTokenizer\n",
"from datasets import load_dataset\n",
"\n",
"from trl import AutoModelForCausalLMWithValueHead\n",
"from trl.core import LengthSampler\n",
"from trl.extras import BestOfNSampler\n",
"\n",
"# 0 selects the first CUDA GPU (used by the reward pipeline and the sampler); otherwise run on CPU.\n",
"device = \"cpu\"\n",
"if torch.cuda.is_available():\n",
"    device = 0"
],
"metadata": {
"id": "M1s_iNm773hM"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Constants: model checkpoints and the best-of-n candidate count"
],
"metadata": {
"id": "Y7hyrIrO8tcY"
}
},
{
"cell_type": "code",
"source": [
"# Checkpoint loaded as the reference policy and its tokenizer.\n",
"ref_model_name = 'lvwerra/gpt2-imdb'\n",
"# Sentiment classifier loaded as the reward pipeline passed to BestOfNSampler.\n",
"reward_model = 'lvwerra/distilbert-imdb'\n",
"\n",
"# NOTE(review): model_name and N_BEST_OF are not referenced by any later cell;\n",
"# in particular N_BEST_OF is not passed to BestOfNSampler - confirm intent.\n",
"model_name = 'lvwerra/gpt2-imdb-pos-v2'\n",
"N_BEST_OF = 4"
],
"metadata": {
"id": "MqS3OM6Q8x6g"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Load the reference model, the reward (sentiment-analysis) pipeline, and the tokenizer"
],
"metadata": {
"id": "c1YcXeElg6or"
}
},
{
"cell_type": "code",
"source": [
"\n",
"# Reference policy (causal LM with a value head), passed to BestOfNSampler as the generating model.\n",
"ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
"\n",
"# Sentiment classifier, passed to BestOfNSampler as the reward pipeline.\n",
"reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
"# The GPT-2 tokenizer ships without a pad token; reuse EOS for padding.\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"# Move the policy to the device chosen in the imports cell instead of calling .cuda()\n",
"# unconditionally, which fails on CPU-only runtimes where `device` is \"cpu\".\n",
"ref_model.to(device);"
],
"metadata": {
"id": "b855NrL181Hh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Build the query dataset: short token prefixes of IMDb training reviews"
],
"metadata": {
"id": "Z1Cz0gCFhZYJ"
}
},
{
"cell_type": "code",
"source": [
"def build_dataset(tokenizer, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n",
"    \"\"\"Build a dataset of short query prefixes taken from reviews.\n",
"\n",
"    Loads the ``train`` split of ``dataset_name``, renames its ``text`` column to\n",
"    ``review``, keeps reviews longer than 200 characters, and adds two columns:\n",
"    ``input_ids`` (the first k tokens of the review, with k drawn per sample from\n",
"    ``LengthSampler(input_min_text_length, input_max_text_length)``) and ``query``\n",
"    (those tokens decoded back to text). The returned dataset uses torch format.\n",
"    \"\"\"\n",
"    ds = load_dataset(dataset_name, split=\"train\")\n",
"    ds = ds.rename_columns({\"text\": \"review\"})\n",
"    # Keep only reviews longer than 200 characters.\n",
"    ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
"\n",
"    input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
"\n",
"    def tokenize(sample):\n",
"        # Truncate to a freshly sampled prefix length, then decode it as the text query.\n",
"        sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
"        sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
"        return sample\n",
"\n",
"    ds = ds.map(tokenize, batched=False)\n",
"    ds.set_format(type=\"torch\")\n",
"    return ds\n",
"\n",
"dataset = build_dataset(tokenizer)"
],
"metadata": {
"id": "LqLVEp5p_8XM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Sampling settings for generate(): top_k=0 and top_p=1.0 switch off top-k and nucleus\n",
"# filtering, and min_length=-1 switches off the minimum-length constraint. top_k is an\n",
"# integer count, so it is written as 0 rather than 0.0.\n",
"gen_kwargs = {\"min_length\": -1, \"top_k\": 0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n",
"# Reward-pipeline call options: scores for every label (top_k=None), raw logits\n",
"# (function_to_apply=\"none\"), scored in batches of 16.\n",
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
],
"metadata": {
"id": "AqA2McjMAxNw"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"output_min_length = 4\n",
"output_max_length = 16\n",
"# Passed to BestOfNSampler as length_sampler; presumably draws the response length per query.\n",
"output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
"\n",
"# Draw a batch of queries from the dataset. random_state fixes which rows are\n",
"# picked, so re-running the notebook uses the same prompts.\n",
"bs = 16\n",
"batch_seed = 0\n",
"output_data = dict()\n",
"# Switches the dataset's format from torch (set in build_dataset) to pandas, in place.\n",
"dataset.set_format(\"pandas\")\n",
"df_batch = dataset[:].sample(bs, random_state=batch_seed)\n",
"output_data[\"query\"] = df_batch[\"query\"].tolist()\n",
"# Token ids of each sampled query, one sequence per row.\n",
"query_tensors = df_batch[\"input_ids\"].tolist()\n",
"\n",
"# NOTE(review): the lists below are never populated or read in this notebook.\n",
"# :: [Resp]\n",
"response_tensors_ref, response_tensors = [], []\n",
"# :: [[Resp]]\n",
"response_tensors_best_of = []"
],
"metadata": {
"id": "L_q4qs35AxcR"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Generate one completion per sampled query with the best-of-n sampler.\n",
"# NOTE(review): N_BEST_OF is not passed here, so the candidate count is the\n",
"# sampler's default on this trl branch - confirm.\n",
"best_of_n_sampler = BestOfNSampler(ref_model, tokenizer, reward_pipe, reward_kwargs=sent_kwargs, length_sampler=output_length_sampler)\n",
"best_of_n_sampler.generate(query_tensors, device=device, **gen_kwargs)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wDv5wz5DiTw4",
"outputId": "d95e4bcc-fccd-4102-8934-768628ab9975"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py:1080: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
" warnings.warn(\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['I rented this film purely on the premise of looking at an aspiring actress who is very',\n",
" 'An independent feature can now be seen via premium cinema where the ever as good Alice and Richard',\n",
" 'When I saw this movie, I really enjoyed it',\n",
" 'This movie has an all-time high for movie retellings, comedy that seems',\n",
" 'I first saw this film about 7 years ago. I was amazed about Finney',\n",
" 'A previous reviewer said the film has an ugly ending, but through the tender moments',\n",
" 'To make any film a buddy to. When the script is too hack',\n",
" 'A recent post here by Michael Tuul-Jung suggests that Newton did this',\n",
" 'Though the award-winning doc',\n",
" 'Steven Seagal, was in Zombie after World War II? And',\n",
" '\"Plants recall something very different this time, living in the',\n",
" \"This is the only movie I've seen that show him that deserves better than 1/10- with the ending\",\n",
" 'I saw Chan Is Missing when it came out several years ago.<',\n",
" 'Not that I want to be mean anymore. I felt like',\n",
" \"I'm sure that rented out all but a\",\n",
" 'If you want to know some more stories you can']"
]
},
"metadata": {},
"execution_count": 8
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment