{
"cells": [
{
"cell_type": "markdown",
"id": "b284d8a4",
"metadata": {},
"source": [
"## F1 score on bigrams"
]
},
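{
"cell_type": "markdown",
"id": "f1-bigrams-overview",
"metadata": {},
"source": [
"Below we tokenize the ground truth and the prediction, turn each token sequence into bigrams, and score the overlap with F1. The tokenization is deliberately minimal and only meant to illustrate the metric."
]
},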
{
"cell_type": "code",
"execution_count": 1,
"id": "4dff0ed2",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import string\n",
"from nltk.util import bigrams\n",
"from collections import Counter\n",
"\n",
"regex_punctuation = re.compile('[%s]' % re.escape(string.punctuation))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d3d7f995",
"metadata": {},
"outputs": [],
"source": [
"gt = \"Vanilla is the best ice cream flavor in the world.\"\n",
"pred = \"Vanilla.\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c225e095",
"metadata": {},
"outputs": [],
"source": [
"def string_to_tokens(s):\n",
" '''\n",
" This is a very basic way of tokenizing your text.\n",
" Probably not useful for real life datasets.\n",
" '''\n",
" return ['<bos>'] + re.sub(regex_punctuation, '' , s.lower()).split() + ['<eos>']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1f0bb882",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<bos>', 'vanilla', 'is', 'the', 'best', 'ice', 'cream', 'flavor', 'in', 'the', 'world', '<eos>'] ['<bos>', 'vanilla', '<eos>']\n"
]
}
],
"source": [
"gt_words = string_to_tokens(gt)\n",
"pred_words = string_to_tokens(pred)\n",
"print(gt_words, pred_words)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "27f1721c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('<bos>', 'vanilla'), ('vanilla', 'is'), ('is', 'the'), ('the', 'best'), ('best', 'ice'), ('ice', 'cream'), ('cream', 'flavor'), ('flavor', 'in'), ('in', 'the'), ('the', 'world'), ('world', '<eos>')] [('<bos>', 'vanilla'), ('vanilla', '<eos>')]\n"
]
}
],
"source": [
"gt_bigrams = list(bigrams(gt_words)) \n",
"pred_bigrams = list(bigrams(pred_words)) \n",
"print(gt_bigrams, pred_bigrams)"
]
},
{
"cell_type": "markdown",
"id": "ff6c9cff",
"metadata": {},
"source": [
"![f1 formula](https://wikimedia.org/api/rest_v1/media/math/render/svg/f5c869c51dba6f1df65a6e6630c516de161632d4)"
]
},
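{
"cell_type": "markdown",
"id": "f1-formula-latex",
"metadata": {},
"source": [
"In terms of the bigram counts used in the `f1` function below, where `shared` is the multiset intersection of predicted and ground-truth bigrams:\n",
"\n",
"$$\\text{precision} = \\frac{|\\text{shared}|}{|\\text{pred bigrams}|}, \\qquad \\text{recall} = \\frac{|\\text{shared}|}{|\\text{gt bigrams}|}, \\qquad F_1 = \\frac{2 \\cdot \\text{precision} \\cdot \\text{recall}}{\\text{precision} + \\text{recall}}$$"
]
},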
{
"cell_type": "code",
"execution_count": 6,
"id": "eda31f54",
"metadata": {},
"outputs": [],
"source": [
"def f1(pred_bigrams, gt_bigrams):\n",
" shared_ngrams = Counter(pred_bigrams) & Counter(gt_bigrams)\n",
" num_same = sum(shared_ngrams.values())\n",
" \n",
" if num_same == 0: return 0\n",
" precision = 1.0 * num_same / len(pred_bigrams)\n",
" recall = 1.0 * num_same / len(gt_bigrams)\n",
" f1 = (2 * precision * recall) / (precision + recall)\n",
" return f1"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8613bb0e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.15384615384615385"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1(pred_bigrams, gt_bigrams)"
]
},
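{
"cell_type": "markdown",
"id": "f1-worked-example",
"metadata": {},
"source": [
"Sanity check by hand: only one bigram is shared, `('<bos>', 'vanilla')`, out of 2 predicted and 11 ground-truth bigrams, so precision = 1/2, recall = 1/11, and F1 = 2 * (1/2) * (1/11) / (1/2 + 1/11) = 2/13 ≈ 0.1538, matching the value above."
]
},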
{
"cell_type": "code",
"execution_count": 8,
"id": "056863bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8181818181818182"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred_bigrams = list(bigrams(string_to_tokens('Chocolate is the best ice cream flavor in the world.')))\n",
"f1(pred_bigrams, gt_bigrams)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8462c989",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8695652173913043"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred_bigrams = list(bigrams(string_to_tokens('Vanilla is not the best ice cream flavor in the world.')))\n",
"f1(pred_bigrams, gt_bigrams)"
]
},
{
"cell_type": "markdown",
"id": "d850c8a3",
"metadata": {},
"source": [
"## BERTscore"
]
},
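{
"cell_type": "markdown",
"id": "bertscore-intro",
"metadata": {},
"source": [
"Roughly speaking, BERTScore embeds every token of the prediction and the reference with a pretrained transformer, greedily matches tokens by cosine similarity, and reports precision (over prediction tokens), recall (over reference tokens), and their F1. Because matching happens in embedding space, paraphrases and synonyms can score well even with little exact n-gram overlap."
]
},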
{
"cell_type": "code",
"execution_count": 10,
"id": "2dc1d5f8",
"metadata": {},
"outputs": [],
"source": [
"# !pip install evaluate\n",
"# !pip install bert_score\n",
"# !pip install pyarrow==11.0.0"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9da20328",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/radek/miniconda3/envs/pytorch/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
" import evaluate\n",
"\n",
"bertscore = evaluate.load('bertscore')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ae052b41",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': [0.6896928548812866],\n",
" 'recall': [0.7891448736190796],\n",
" 'f1': [0.7360748052597046],\n",
" 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.29.2)'}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bertscore.compute(\n",
" predictions=['Vanilla is the best ice cream flavor in the world.'],\n",
" references=['Vanilla.'],\n",
" model_type=\"distilbert-base-uncased\"\n",
")"
]
},
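{
"cell_type": "markdown",
"id": "bertscore-asymmetry",
"metadata": {},
"source": [
"Note that precision is computed over tokens of `predictions` and recall over tokens of `references`, so swapping the two arguments swaps precision and recall while leaving F1 unchanged."
]
},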
{
"cell_type": "code",
"execution_count": 13,
"id": "400c40ff",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': [0.9863595962524414],\n",
" 'recall': [0.9863595962524414],\n",
" 'f1': [0.9863595962524414],\n",
" 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.29.2)'}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bertscore.compute(\n",
" predictions=['Vanilla is the best ice cream flavor in the world.'],\n",
" references=['Chocolate is the best ice cream flavor in the world.'],\n",
" model_type=\"distilbert-base-uncased\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "fe4e3421",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': [0.9881718158721924],\n",
" 'recall': [0.9713627099990845],\n",
" 'f1': [0.979695200920105],\n",
" 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.29.2)'}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bertscore.compute(\n",
" predictions=['Vanilla is the best ice cream flavor in the world.'],\n",
" references=['Vanilla is not the best ice cream flavor in the world.'],\n",
" model_type=\"distilbert-base-uncased\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "cd316b0c",
"metadata": {},
"source": [
"## Perplexity"
]
},
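{
"cell_type": "markdown",
"id": "perplexity-intro",
"metadata": {},
"source": [
"Perplexity measures how well a language model predicts a piece of text: it is the exponential of the average per-token negative log-likelihood (cross-entropy), so lower values mean the model finds the text more predictable."
]
},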
{
"cell_type": "code",
"execution_count": 15,
"id": "a56cd1cf",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"import transformers\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"\n",
"# Code from a wonderful Kaggle Notebook by Pilipp Singer\n",
"# https://www.kaggle.com/code/philippsinger/h2ogpt-perplexity-ranking\n",
"\n",
"class Perplexity(nn.Module):\n",
" def __init__(self, reduce: bool = True):\n",
" super().__init__()\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
" self.reduce = reduce\n",
"\n",
" def forward(self, logits, labels):\n",
" shift_logits = logits[..., :-1, :].contiguous()\n",
" shift_labels = labels[..., 1:].contiguous()\n",
"\n",
" perplexity = []\n",
" for i in range(labels.shape[0]):\n",
" perplexity.append(self.loss_fn(shift_logits[i], shift_labels[i]))\n",
" perplexity = torch.stack(perplexity, dim=0)\n",
" if self.reduce:\n",
" perplexity = torch.mean(perplexity)\n",
" return perplexity \n",
" \n",
"perp = Perplexity()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "1d7fa1e6",
"metadata": {},
"outputs": [],
"source": [
"def perplexity(model, prompt):\n",
" tokenizer = AutoTokenizer.from_pretrained(model)\n",
"\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" model,\n",
" torch_dtype=torch.float16,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True,\n",
" )\n",
" \n",
" with torch.no_grad():\n",
" inputs = tokenizer([prompt], return_tensors=\"pt\", add_special_tokens=False, truncation=True).to(\"cuda\")\n",
" logits = model(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"]).logits\n",
" labels = inputs[\"input_ids\"]\n",
" return perp(logits[0].unsqueeze(0), labels[0].unsqueeze(0)).item()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "af413bfa",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
" Q: Which ice cream flavor is the best?\n",
" A: Vanilla is the best ice cream flavor in the world.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "28d6c4a8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00, 3.19s/it]\n"
]
},
{
"data": {
"text/plain": [
"2.111328125"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"perplexity(\"tiiuae/falcon-7b\", prompt)"
]
},
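{
"cell_type": "code",
"execution_count": null,
"id": "ce-to-perplexity-sketch",
"metadata": {},
"outputs": [],
"source": [
"# The score above is the mean cross-entropy in nats; exponentiating it (sketch,\n",
"# not run here) gives perplexity in the conventional sense. Rankings are the\n",
"# same either way, since exp is monotonic.\n",
"# import math\n",
"# math.exp(2.111328125)  # ~8.3"
]
},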
{
"cell_type": "code",
"execution_count": 19,
"id": "da110e26",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instantiating an MPTForCausalLM model from /home/radek/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7/modeling_mpt.py\n",
"You are using config.init_device='cpu', but you can also use config.init_device=\"meta\" with Composer + FSDP for fast initialization.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00, 3.07s/it]\n"
]
},
{
"data": {
"text/plain": [
"1.77734375"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"perplexity(\"mosaicml/mpt-7b\", prompt)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}