@gromgull
Created August 30, 2023 09:59
WordPiece Tokenizer tutorial from Hugging Face
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "2ad76424",
"metadata": {},
"outputs": [],
"source": [
"corpus = [\n",
" \"This is the Hugging Face Course.\",\n",
" \"This chapter is about tokenization.\",\n",
" \"This section shows several tokenizer algorithms.\",\n",
" \"Hopefully, you will be able to understand how they are trained and generate tokens.\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "423e51e6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-30 10:21:48.102686: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "41f947d40b6644909567b3e2ce5bd972",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/29.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "342f8f82db634ba0b34adc9cf302e02e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/570 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "540f5a70194e48a28de77036cde5a98a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/213k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "83b10bfbe4754b9480ba5d858aea6fae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/436k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "31dc11ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"defaultdict(int,\n",
" {'This': 3,\n",
" 'is': 2,\n",
" 'the': 1,\n",
" 'Hugging': 1,\n",
" 'Face': 1,\n",
" 'Course': 1,\n",
" '.': 4,\n",
" 'chapter': 1,\n",
" 'about': 1,\n",
" 'tokenization': 1,\n",
" 'section': 1,\n",
" 'shows': 1,\n",
" 'several': 1,\n",
" 'tokenizer': 1,\n",
" 'algorithms': 1,\n",
" 'Hopefully': 1,\n",
" ',': 1,\n",
" 'you': 1,\n",
" 'will': 1,\n",
" 'be': 1,\n",
" 'able': 1,\n",
" 'to': 1,\n",
" 'understand': 1,\n",
" 'how': 1,\n",
" 'they': 1,\n",
" 'are': 1,\n",
" 'trained': 1,\n",
" 'and': 1,\n",
" 'generate': 1,\n",
" 'tokens': 1})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import defaultdict\n",
"\n",
"word_freqs = defaultdict(int)\n",
"for text in corpus:\n",
" words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
" new_words = [word for word, offset in words_with_offsets]\n",
" for word in new_words:\n",
" word_freqs[word] += 1\n",
"\n",
"word_freqs"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b048c642",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y']\n"
]
}
],
"source": [
"alphabet = []\n",
"for word in word_freqs.keys():\n",
" if word[0] not in alphabet:\n",
" alphabet.append(word[0])\n",
" for letter in word[1:]:\n",
" if f\"##{letter}\" not in alphabet:\n",
" alphabet.append(f\"##{letter}\")\n",
"\n",
"alphabet.sort()\n",
"alphabet\n",
"\n",
"print(alphabet)"
]
},
{
"cell_type": "markdown",
"id": "78c709f7",
"metadata": {},
"source": [
"\"today\" => t, ##o ##d "
]
},
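{
"cell_type": "markdown",
"id": "split-example-md",
"metadata": {},
"source": [
"A quick check of the split above (a minimal sketch, using the same first-character / `##`-prefix convention as the `splits` dictionary built below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "split-example-code",
"metadata": {},
"outputs": [],
"source": [
"# Split one word the WordPiece way: keep the first character as-is and\n",
"# prefix every following character with \"##\".\n",
"word = \"today\"\n",
"print([c if i == 0 else f\"##{c}\" for i, c in enumerate(word)])  # ['t', '##o', '##d', '##a', '##y']"
]
},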
{
"cell_type": "code",
"execution_count": 5,
"id": "2a97425b",
"metadata": {},
"outputs": [],
"source": [
"vocab = [\"[PAD]\", \"[UNK]\", \"[CLS]\", \"[SEP]\", \"[MASK]\"] + alphabet.copy()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cfe2d694",
"metadata": {},
"outputs": [],
"source": [
"splits = {\n",
" word: [c if i == 0 else f\"##{c}\" for i, c in enumerate(word)]\n",
" for word in word_freqs.keys()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7bb20d84",
"metadata": {},
"outputs": [],
"source": [
"def compute_pair_scores(splits):\n",
" letter_freqs = defaultdict(int)\n",
" pair_freqs = defaultdict(int)\n",
" for word, freq in word_freqs.items():\n",
" split = splits[word]\n",
" if len(split) == 1:\n",
" letter_freqs[split[0]] += freq\n",
" continue\n",
" for i in range(len(split) - 1):\n",
" pair = (split[i], split[i + 1])\n",
" letter_freqs[split[i]] += freq\n",
" pair_freqs[pair] += freq\n",
" letter_freqs[split[-1]] += freq\n",
"\n",
" scores = {\n",
" pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])\n",
" for pair, freq in pair_freqs.items()\n",
" }\n",
" return scores"
]
},
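{
"cell_type": "markdown",
"id": "pair-score-formula-md",
"metadata": {},
"source": [
"The quantity computed above is the WordPiece pair score used to rank candidate merges:\n",
"\n",
"$$\\mathrm{score}(a, b) = \\frac{\\mathrm{freq}(a, b)}{\\mathrm{freq}(a)\\,\\mathrm{freq}(b)}$$\n",
"\n",
"so a pair scores highly when its two parts occur together often relative to how often each occurs on its own."
]
},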
{
"cell_type": "code",
"execution_count": 9,
"id": "5d2c1ccd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('T', '##h'): 0.125\n",
"('##h', '##i'): 0.03409090909090909\n",
"('##i', '##s'): 0.02727272727272727\n",
"('i', '##s'): 0.1\n",
"('t', '##h'): 0.03571428571428571\n",
"('##h', '##e'): 0.011904761904761904\n"
]
}
],
"source": [
"pair_scores = compute_pair_scores(splits)\n",
"for i, key in enumerate(pair_scores.keys()):\n",
" print(f\"{key}: {pair_scores[key]}\")\n",
" if i >= 5:\n",
" break\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2b8368b3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('a', '##b') 0.2\n"
]
}
],
"source": [
"best_pair = \"\"\n",
"max_score = None\n",
"for pair, score in pair_scores.items():\n",
" if max_score is None or max_score < score:\n",
" best_pair = pair\n",
" max_score = score\n",
"\n",
"print(best_pair, max_score)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b01a4ff1",
"metadata": {},
"outputs": [],
"source": [
"vocab.append(\"ab\")\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "37667b0e",
"metadata": {},
"outputs": [],
"source": [
"def merge_pair(a, b, splits):\n",
" for word in word_freqs:\n",
" split = splits[word]\n",
" if len(split) == 1:\n",
" continue\n",
" i = 0\n",
" while i < len(split) - 1:\n",
" if split[i] == a and split[i + 1] == b:\n",
" merge = a + b[2:] if b.startswith(\"##\") else a + b\n",
" split = split[:i] + [merge] + split[i + 2 :]\n",
" else:\n",
" i += 1\n",
" splits[word] = split\n",
" return splits\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "86294ab2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['ab', '##o', '##u', '##t']"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"splits = merge_pair(\"a\", \"##b\", splits)\n",
"splits[\"about\"]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e8f65e92",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 70\n",
"while len(vocab) < vocab_size:\n",
" scores = compute_pair_scores(splits)\n",
" best_pair, max_score = \"\", None\n",
" for pair, score in scores.items():\n",
" if max_score is None or max_score < score:\n",
" best_pair = pair\n",
" max_score = score\n",
" splits = merge_pair(*best_pair, splits)\n",
" new_token = (\n",
" best_pair[0] + best_pair[1][2:]\n",
" if best_pair[1].startswith(\"##\")\n",
" else best_pair[0] + best_pair[1]\n",
" )\n",
" vocab.append(new_token)"
]
},
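{
"cell_type": "markdown",
"id": "vocab-size-note-md",
"metadata": {},
"source": [
"A small sanity check on the target size: the 5 special tokens plus the 40-character alphabet give 45 initial entries, the hand-added `ab` merge makes 46, so the loop above learns 24 further merges to reach `vocab_size = 70`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vocab-size-note-code",
"metadata": {},
"outputs": [],
"source": [
"# 5 special tokens + 40 alphabet symbols + \"ab\" + 24 learned merges = 70\n",
"print(len(vocab))"
]
},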
{
"cell_type": "code",
"execution_count": 15,
"id": "1f5cb209",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab', '##fu', 'Fa', 'Fac', '##ct', '##ful', '##full', '##fully', 'Th', 'ch', '##hm', 'cha', 'chap', 'chapt', '##thm', 'Hu', 'Hug', 'Hugg', 'sh', 'th', 'is', '##thms', '##za', '##zat', '##ut']\n"
]
}
],
"source": [
"print(vocab)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0f237ae9",
"metadata": {},
"outputs": [],
"source": [
"def encode_word(word):\n",
" tokens = []\n",
" while len(word) > 0:\n",
" i = len(word)\n",
" while i > 0 and word[:i] not in vocab:\n",
" i -= 1\n",
" if i == 0:\n",
" return [\"[UNK]\"]\n",
" tokens.append(word[:i])\n",
" word = word[i:]\n",
" if len(word) > 0:\n",
" word = f\"##{word}\"\n",
" return tokens"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d35e52ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Hugg', '##i', '##n', '##g']\n",
"['[UNK]']\n"
]
}
],
"source": [
"print(encode_word(\"Hugging\"))\n",
"print(encode_word(\"HOgging\"))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "cd602f7c",
"metadata": {},
"outputs": [],
"source": [
"def tokenize(text):\n",
" pre_tokenize_result = tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
" pre_tokenized_text = [word for word, offset in pre_tokenize_result]\n",
" encoded_words = [encode_word(word) for word in pre_tokenized_text]\n",
" return sum(encoded_words, [])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "12642dfe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Th',\n",
" '##i',\n",
" '##s',\n",
" 'is',\n",
" 'th',\n",
" '##e',\n",
" 'Hugg',\n",
" '##i',\n",
" '##n',\n",
" '##g',\n",
" 'Fac',\n",
" '##e',\n",
" 'c',\n",
" '##o',\n",
" '##u',\n",
" '##r',\n",
" '##s',\n",
" '##e',\n",
" '[UNK]']"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenize(\"This is the Hugging Face course!\")\n"
]
},
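{
"cell_type": "markdown",
"id": "unk-note-md",
"metadata": {},
"source": [
"The trailing `!` is pre-tokenized into a word of its own, and since it never appears in the training corpus it is missing from `vocab`, so `encode_word` falls back to `[UNK]`. A quick check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unk-note-code",
"metadata": {},
"outputs": [],
"source": [
"print(\"!\" in vocab)      # False: \"!\" never occurs in the corpus\n",
"print(encode_word(\"!\"))  # ['[UNK]']"
]
},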
{
"cell_type": "code",
"execution_count": null,
"id": "8ecbb8d1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}