WordPiece tokenizer tutorial from the Hugging Face course
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2ad76424",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    \"This is the Hugging Face Course.\",\n",
    "    \"This chapter is about tokenization.\",\n",
    "    \"This section shows several tokenizer algorithms.\",\n",
    "    \"Hopefully, you will be able to understand how they are trained and generate tokens.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "423e51e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-30 10:21:48.102686: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41f947d40b6644909567b3e2ce5bd972",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading: 0%| | 0.00/29.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "342f8f82db634ba0b34adc9cf302e02e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading: 0%| | 0.00/570 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "540f5a70194e48a28de77036cde5a98a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading: 0%| | 0.00/213k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83b10bfbe4754b9480ba5d858aea6fae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading: 0%| | 0.00/436k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "31dc11ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {'This': 3,\n",
       "             'is': 2,\n",
       "             'the': 1,\n",
       "             'Hugging': 1,\n",
       "             'Face': 1,\n",
       "             'Course': 1,\n",
       "             '.': 4,\n",
       "             'chapter': 1,\n",
       "             'about': 1,\n",
       "             'tokenization': 1,\n",
       "             'section': 1,\n",
       "             'shows': 1,\n",
       "             'several': 1,\n",
       "             'tokenizer': 1,\n",
       "             'algorithms': 1,\n",
       "             'Hopefully': 1,\n",
       "             ',': 1,\n",
       "             'you': 1,\n",
       "             'will': 1,\n",
       "             'be': 1,\n",
       "             'able': 1,\n",
       "             'to': 1,\n",
       "             'understand': 1,\n",
       "             'how': 1,\n",
       "             'they': 1,\n",
       "             'are': 1,\n",
       "             'trained': 1,\n",
       "             'and': 1,\n",
       "             'generate': 1,\n",
       "             'tokens': 1})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "word_freqs = defaultdict(int)\n",
    "for text in corpus:\n",
    "    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    new_words = [word for word, offset in words_with_offsets]\n",
    "    for word in new_words:\n",
    "        word_freqs[word] += 1\n",
    "\n",
    "word_freqs"
   ]
  },
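  {
   "cell_type": "markdown",
   "id": "c0ffee0a",
   "metadata": {},
   "source": [
    "An added aside (not in the original gist): `pre_tokenize_str` returns `(word, (start, end))` pairs; only the word is kept above and the offsets are discarded. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: inspect the raw pre-tokenizer output for a short string;\n",
    "# each entry pairs a word with its character span in the input.\n",
    "tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(\"Hugging Face!\")"
   ]
  },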
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b048c642",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y']\n"
     ]
    }
   ],
   "source": [
    "alphabet = []\n",
    "for word in word_freqs.keys():\n",
    "    if word[0] not in alphabet:\n",
    "        alphabet.append(word[0])\n",
    "    for letter in word[1:]:\n",
    "        if f\"##{letter}\" not in alphabet:\n",
    "            alphabet.append(f\"##{letter}\")\n",
    "\n",
    "alphabet.sort()\n",
    "\n",
    "print(alphabet)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78c709f7",
   "metadata": {},
   "source": [
"\"today\" => t, ##o ##d " | |
   ]
  },
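  {
   "cell_type": "markdown",
   "id": "c0ffee01",
   "metadata": {},
   "source": [
    "A quick sketch of that splitting rule, added for illustration (not part of the original gist):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added illustration: the first character keeps no prefix; every later\n",
    "# character gets '##', exactly as in the note above.\n",
    "word = \"today\"\n",
    "print([c if i == 0 else f\"##{c}\" for i, c in enumerate(word)])\n",
    "# -> ['t', '##o', '##d', '##a', '##y']"
   ]
  },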
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2a97425b",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = [\"[PAD]\", \"[UNK]\", \"[CLS]\", \"[SEP]\", \"[MASK]\"] + alphabet.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cfe2d694",
   "metadata": {},
   "outputs": [],
   "source": [
    "splits = {\n",
    "    word: [c if i == 0 else f\"##{c}\" for i, c in enumerate(word)]\n",
    "    for word in word_freqs.keys()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7bb20d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_pair_scores(splits):\n",
    "    letter_freqs = defaultdict(int)\n",
    "    pair_freqs = defaultdict(int)\n",
    "    for word, freq in word_freqs.items():\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            letter_freqs[split[0]] += freq\n",
    "            continue\n",
    "        for i in range(len(split) - 1):\n",
    "            pair = (split[i], split[i + 1])\n",
    "            letter_freqs[split[i]] += freq\n",
    "            pair_freqs[pair] += freq\n",
    "        letter_freqs[split[-1]] += freq\n",
    "\n",
    "    scores = {\n",
    "        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])\n",
    "        for pair, freq in pair_freqs.items()\n",
    "    }\n",
    "    return scores"
   ]
  },
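  {
   "cell_type": "markdown",
   "id": "c0ffee03",
   "metadata": {},
   "source": [
    "Added note: the dictionary comprehension above implements the WordPiece pair score\n",
    "\n",
    "$$\\mathrm{score}(a, b) = \\frac{\\mathrm{freq}(a, b)}{\\mathrm{freq}(a) \\cdot \\mathrm{freq}(b)}$$\n",
    "\n",
    "so a pair wins by being frequent *relative to its parts*, unlike BPE, which merges the most frequent pair outright."
   ]
  },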
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5d2c1ccd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('T', '##h'): 0.125\n",
      "('##h', '##i'): 0.03409090909090909\n",
      "('##i', '##s'): 0.02727272727272727\n",
      "('i', '##s'): 0.1\n",
      "('t', '##h'): 0.03571428571428571\n",
      "('##h', '##e'): 0.011904761904761904\n"
     ]
    }
   ],
   "source": [
    "pair_scores = compute_pair_scores(splits)\n",
    "for i, key in enumerate(pair_scores.keys()):\n",
    "    print(f\"{key}: {pair_scores[key]}\")\n",
    "    if i >= 5:\n",
    "        break"
   ]
  },
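  {
   "cell_type": "markdown",
   "id": "c0ffee04",
   "metadata": {},
   "source": [
    "A hand check of the first score, added for illustration: `'T'` occurs 3 times (all from \"This\"), `'##h'` occurs 8 times across the corpus splits, and they are adjacent 3 times, so the score is $3 / (3 \\cdot 8) = 0.125$, matching the printout above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added check: 3 pair occurrences / (3 occurrences of 'T' * 8 of '##h').\n",
    "print(pair_scores[(\"T\", \"##h\")])  # 0.125"
   ]
  },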
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2b8368b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('a', '##b') 0.2\n"
     ]
    }
   ],
   "source": [
    "best_pair = \"\"\n",
    "max_score = None\n",
    "for pair, score in pair_scores.items():\n",
    "    if max_score is None or max_score < score:\n",
    "        best_pair = pair\n",
    "        max_score = score\n",
    "\n",
    "print(best_pair, max_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b01a4ff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab.append(\"ab\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "37667b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_pair(a, b, splits):\n",
    "    for word in word_freqs:\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            continue\n",
    "        i = 0\n",
    "        while i < len(split) - 1:\n",
    "            if split[i] == a and split[i + 1] == b:\n",
    "                # Merging drops b's '##' prefix: 'a' + '##b' -> 'ab'.\n",
    "                merge = a + b[2:] if b.startswith(\"##\") else a + b\n",
    "                split = split[:i] + [merge] + split[i + 2 :]\n",
    "            else:\n",
    "                i += 1\n",
    "        splits[word] = split\n",
    "    return splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "86294ab2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['ab', '##o', '##u', '##t']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "splits = merge_pair(\"a\", \"##b\", splits)\n",
    "splits[\"about\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e8f65e92",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = 70\n",
    "while len(vocab) < vocab_size:\n",
    "    scores = compute_pair_scores(splits)\n",
    "    best_pair, max_score = \"\", None\n",
    "    for pair, score in scores.items():\n",
    "        if max_score is None or max_score < score:\n",
    "            best_pair = pair\n",
    "            max_score = score\n",
    "    splits = merge_pair(*best_pair, splits)\n",
    "    new_token = (\n",
    "        best_pair[0] + best_pair[1][2:]\n",
    "        if best_pair[1].startswith(\"##\")\n",
    "        else best_pair[0] + best_pair[1]\n",
    "    )\n",
    "    vocab.append(new_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1f5cb209",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab', '##fu', 'Fa', 'Fac', '##ct', '##ful', '##full', '##fully', 'Th', 'ch', '##hm', 'cha', 'chap', 'chapt', '##thm', 'Hu', 'Hug', 'Hugg', 'sh', 'th', 'is', '##thms', '##za', '##zat', '##ut']\n"
     ]
    }
   ],
   "source": [
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0f237ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_word(word):\n",
    "    tokens = []\n",
    "    while len(word) > 0:\n",
    "        i = len(word)\n",
    "        # Greedy longest-match-first: shrink the prefix until it is in the vocab.\n",
    "        while i > 0 and word[:i] not in vocab:\n",
    "            i -= 1\n",
    "        if i == 0:\n",
    "            return [\"[UNK]\"]\n",
    "        tokens.append(word[:i])\n",
    "        word = word[i:]\n",
    "        if len(word) > 0:\n",
    "            word = f\"##{word}\"\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d35e52ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Hugg', '##i', '##n', '##g']\n",
      "['[UNK]']\n"
     ]
    }
   ],
   "source": [
    "print(encode_word(\"Hugging\"))\n",
    "print(encode_word(\"HOgging\"))"
   ]
  },
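  {
   "cell_type": "markdown",
   "id": "c0ffee06",
   "metadata": {},
   "source": [
    "Added note: unlike BPE, WordPiece marks the *whole* word as unknown when any piece fails to match. `\"HOgging\"` matches `'H'`, but the remainder `'##Ogging'` has no prefix in the 70-token vocabulary, so the result is `['[UNK]']`. A minimal check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added check: no piece can start the remainder '##Ogging',\n",
    "# because 'O' never appeared in the training corpus.\n",
    "print(\"H\" in vocab, \"##O\" in vocab)  # True False"
   ]
  },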
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cd602f7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(text):\n",
    "    pre_tokenize_result = tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    pre_tokenized_text = [word for word, offset in pre_tokenize_result]\n",
    "    encoded_words = [encode_word(word) for word in pre_tokenized_text]\n",
    "    return sum(encoded_words, [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "12642dfe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Th',\n",
       " '##i',\n",
       " '##s',\n",
       " 'is',\n",
       " 'th',\n",
       " '##e',\n",
       " 'Hugg',\n",
       " '##i',\n",
       " '##n',\n",
       " '##g',\n",
       " 'Fac',\n",
       " '##e',\n",
       " 'c',\n",
       " '##o',\n",
       " '##u',\n",
       " '##r',\n",
       " '##s',\n",
       " '##e',\n",
       " '[UNK]']"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenize(\"This is the Hugging Face course!\")"
   ]
  },
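  {
   "cell_type": "markdown",
   "id": "c0ffee08",
   "metadata": {},
   "source": [
    "For comparison, an added cell (not in the original gist): the pretrained `bert-base-cased` tokenizer runs the same kind of WordPiece matching with its full vocabulary of tens of thousands of tokens, so the sentence should come back in far fewer, larger pieces:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added comparison: the pretrained vocabulary covers most whole words,\n",
    "# so far less splitting is expected than with our 70-token toy vocab.\n",
    "print(tokenizer.tokenize(\"This is the Hugging Face course!\"))"
   ]
  },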
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ecbb8d1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}