@xenova
Last active May 10, 2024 00:59
Convert tiktoken tokenizers to the Hugging Face tokenizers format
{
"cells": [
{
"cell_type": "markdown",
"source": [
"### Install requirements"
],
"metadata": {
"id": "2krxXyYOEsAj"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qD4__VYRE9ep"
},
"outputs": [],
"source": [
"!pip install -q tiktoken transformers"
]
},
{
"cell_type": "markdown",
"source": [
"### Setup"
],
"metadata": {
"id": "OcCezqFbEvVN"
}
},
{
"cell_type": "code",
"source": [
"\n",
"# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb\n",
"MODEL_INFO = {\n",
"    # GPT-2 and GPT-3 models (r50k_base)\n",
"    'gpt2': {\n",
"        'tokenizer_class': 'GPT2Tokenizer',\n",
"        'model_max_length': 1024,\n",
"    },\n",
"    'davinci': {  # (gpt-3)\n",
"        'tokenizer_class': 'GPT3Tokenizer',\n",
"        'model_max_length': 2048,\n",
"    },\n",
"\n",
"    # GPT-3.5 and GPT-4 models (cl100k_base)\n",
"    'gpt-3.5-turbo': {\n",
"        'tokenizer_class': 'GPT3_5Tokenizer',\n",
"        'model_max_length': 4096,\n",
"    },\n",
"    'gpt-3.5-turbo-16k': {\n",
"        'tokenizer_class': 'GPT3_5Tokenizer',\n",
"        'model_max_length': 16384,\n",
"    },\n",
"    'gpt-4': {\n",
"        'tokenizer_class': 'GPT4Tokenizer',\n",
"        'model_max_length': 8192,\n",
"    },\n",
"    'text-embedding-ada-002': {\n",
"        'tokenizer_class': 'GPT4Tokenizer',\n",
"        'model_max_length': 8192,\n",
"    },\n",
"\n",
"    # Codex models (p50k_base)\n",
"    'text-davinci-002': {\n",
"        'tokenizer_class': 'CodexTokenizer',\n",
"        'model_max_length': 4096,\n",
"    },\n",
"    'text-davinci-003': {\n",
"        'tokenizer_class': 'CodexTokenizer',\n",
"        'model_max_length': 4096,\n",
"    },\n",
"}\n"
],
"metadata": {
"id": "UuNt2kwgFWbN"
},
"execution_count": null,
"outputs": []
},
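{
"cell_type": "markdown",
"source": [
"A quick sanity check (a sketch; it assumes the installed `tiktoken` release recognizes every model name above): each entry in `MODEL_INFO` should resolve to one of the three encodings named in the comments."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: print which tiktoken encoding each model name resolves to.\n",
"# Expected: r50k_base (gpt2, davinci), cl100k_base (gpt-3.5-turbo*, gpt-4,\n",
"# text-embedding-ada-002) and p50k_base (text-davinci-002/003).\n",
"import tiktoken\n",
"\n",
"for model_name in MODEL_INFO:\n",
"    print(model_name, '->', tiktoken.encoding_for_model(model_name).name)"
]
},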
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZypJVeIMFQGQ"
},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"import tiktoken\n",
"from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode\n",
"from typing import Optional\n",
"\n",
"byte_encoder = bytes_to_unicode()\n",
"\n",
"def token_bytes_to_string(b):\n",
"    return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])\n",
"\n",
"# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960\n",
"def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:\n",
"    parts = [bytes([b]) for b in token]\n",
"    while True:\n",
"        min_idx = None\n",
"        min_rank = None\n",
"        for i, pair in enumerate(zip(parts[:-1], parts[1:])):\n",
"            rank = mergeable_ranks.get(pair[0] + pair[1])\n",
"            if rank is not None and (min_rank is None or rank < min_rank):\n",
"                min_idx = i\n",
"                min_rank = rank\n",
"        if min_rank is None or (max_rank is not None and min_rank >= max_rank):\n",
"            break\n",
"        assert min_idx is not None\n",
"        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]\n",
"    return parts\n",
"\n",
"def generate_vocab_and_merges(encoder):\n",
"    mergeable_ranks = encoder._mergeable_ranks\n",
"\n",
"    merges = []\n",
"    vocab = {}\n",
"    for token, rank in mergeable_ranks.items():\n",
"        vocab[token_bytes_to_string(token)] = rank\n",
"\n",
"        if len(token) == 1:\n",
"            continue\n",
"        merged = tuple(bpe(mergeable_ranks, token, max_rank=rank))\n",
"        assert len(merged) == 2\n",
"\n",
"        merges.append(' '.join(map(token_bytes_to_string, merged)))\n",
"\n",
"    # Also add special tokens\n",
"    vocab.update(encoder._special_tokens)\n",
"\n",
"    return vocab, merges\n",
"\n",
"def convert_tiktoken(model_name, output_dir=None):\n",
"    if output_dir is None:\n",
"        output_dir = model_name\n",
"\n",
"    encoder = tiktoken.encoding_for_model(model_name)\n",
"\n",
"    vocab, merges = generate_vocab_and_merges(encoder)\n",
"\n",
"    added_tokens = [\n",
"        {\n",
"            \"id\": id,\n",
"            \"content\": content,\n",
"            \"single_word\": False,\n",
"            \"lstrip\": False,\n",
"            \"rstrip\": False,\n",
"            \"normalized\": False,\n",
"            \"special\": True,\n",
"        }\n",
"        for content, id in encoder._special_tokens.items()\n",
"    ]\n",
"\n",
"    # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer_config.json\n",
"    tokenizer_config_template = {\n",
"        \"add_prefix_space\": False,\n",
"        \"bos_token\": \"<|endoftext|>\",\n",
"        \"clean_up_tokenization_spaces\": False,\n",
"        \"eos_token\": \"<|endoftext|>\",\n",
"        \"unk_token\": \"<|endoftext|>\",\n",
"    }\n",
"    tokenizer_config_template.update(MODEL_INFO[model_name])  # Adds `model_max_length` and `tokenizer_class`\n",
"    tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0]))\n",
"\n",
"    os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"    if MODEL_INFO[model_name]['tokenizer_class'] in ('GPT3_5Tokenizer', 'GPT4Tokenizer'):\n",
"        pre_tokenizer = {\n",
"            \"type\": \"Sequence\",\n",
"            \"pretokenizers\": [\n",
"                {\n",
"                    \"type\": \"Split\",\n",
"                    \"pattern\": {\n",
"                        \"Regex\": \"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\\\r\\\\n\\\\p{L}\\\\p{N}]?\\\\p{L}+|\\\\p{N}{1,3}| ?[^\\\\s\\\\p{L}\\\\p{N}]+[\\\\r\\\\n]*|\\\\s*[\\\\r\\\\n]+|\\\\s+(?!\\\\S)|\\\\s+\"\n",
"                    },\n",
"                    \"behavior\": \"Removed\",\n",
"                    \"invert\": True,\n",
"                },\n",
"                {\n",
"                    \"type\": \"ByteLevel\",\n",
"                    \"add_prefix_space\": False,\n",
"                    \"trim_offsets\": True,\n",
"                    \"use_regex\": False,\n",
"                }\n",
"            ]\n",
"        }\n",
"    else:\n",
"        pre_tokenizer = {\n",
"            \"type\": \"ByteLevel\",\n",
"            \"add_prefix_space\": False,\n",
"            \"trim_offsets\": True,\n",
"            \"use_regex\": True,\n",
"        }\n",
"\n",
"    # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer.json\n",
"    tokenizer_template = {\n",
"        \"version\": \"1.0\",\n",
"        \"truncation\": None,\n",
"        \"padding\": None,\n",
"        \"added_tokens\": added_tokens,\n",
"        \"normalizer\": None,\n",
"        \"pre_tokenizer\": pre_tokenizer,\n",
"        \"post_processor\": None,\n",
"        \"decoder\": {\n",
"            \"type\": \"ByteLevel\",\n",
"            \"add_prefix_space\": True,\n",
"            \"trim_offsets\": True,\n",
"            \"use_regex\": True,\n",
"        },\n",
"        \"model\": {\n",
"            \"type\": \"BPE\",\n",
"            \"dropout\": None,\n",
"            \"unk_token\": None,\n",
"            \"continuing_subword_prefix\": \"\",\n",
"            \"end_of_word_suffix\": \"\",\n",
"            \"fuse_unk\": False,\n",
"            \"byte_fallback\": False,\n",
"            \"vocab\": vocab,\n",
"            \"merges\": merges,\n",
"        },\n",
"    }\n",
"\n",
"    # Save to files\n",
"    with open(os.path.join(output_dir, 'vocab.json'), 'w', encoding='utf-8') as fp:\n",
"        json.dump(vocab, fp, indent=2, ensure_ascii=False)\n",
"\n",
"    with open(os.path.join(output_dir, 'tokenizer.json'), 'w', encoding='utf-8') as fp:\n",
"        json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False)\n",
"\n",
"    with open(os.path.join(output_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as fp:\n",
"        json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False)\n",
"\n",
"    with open(os.path.join(output_dir, 'special_tokens_map.json'), 'w', encoding='utf-8') as fp:\n",
"        json.dump({\n",
"            \"bos_token\": \"<|endoftext|>\",\n",
"            \"eos_token\": \"<|endoftext|>\",\n",
"            \"unk_token\": \"<|endoftext|>\",\n",
"        }, fp, indent=2, ensure_ascii=False)\n",
"\n",
"    with open(os.path.join(output_dir, 'merges.txt'), 'w', encoding='utf-8') as fp:\n",
"        fp.write('#version: 0.2\\n')\n",
"        fp.write('\\n'.join(merges))"
]
},
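{
"cell_type": "markdown",
"source": [
"To illustrate the merge-recovery trick in `bpe` above (a sketch; using `b' the'` as an example multi-byte token is an assumption): capping `max_rank` at a token's own rank stops the BPE loop one step early, which exposes the pair of parent tokens whose merge produces it."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: recover the merge pair for a single multi-byte gpt2 token.\n",
"enc = tiktoken.encoding_for_model('gpt2')\n",
"ranks = enc._mergeable_ranks\n",
"token = b' the'  # assumed present as a multi-byte token in this vocabulary\n",
"print(bpe(ranks, token, max_rank=ranks[token]))  # its two parent byte strings"
]
},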
{
"cell_type": "markdown",
"source": [
"### Run conversion"
],
"metadata": {
"id": "wfuFCZRbFMT_"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O87Zz6Vzhb5C"
},
"outputs": [],
"source": [
"output = 'models'\n",
"for model_name in MODEL_INFO:\n",
" convert_tiktoken(model_name, os.path.join(output, model_name))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qx6tfE_UwFNB"
},
"source": [
"### Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oSRUBMLmwatB"
},
"outputs": [],
"source": [
"# Tests adapted from https://github.com/openai/tiktoken/blob/1b9faf2779855124f05174adf1383e53689ed94b/tests/test_encoding.py\n",
"TESTS = [\n",
"    \"\\n\\n\\n\\n\\ns1232\", \"hello world\", \"hello <|endoftext|>\", \"hello world\", \"hello <|endoftext|>\", \"0\", \"00\", \"000\", \"0000\", \"00000\", \"000000\", \"0000000\", \"00000000\", \"000000000\", \"0000000000\", \"00000000000\", \"000000000000\", \"0000000000000\", \"00000000000000\", \"000000000000000\", \"0000000000000000\", \"00000000000000000\", \"rer\", \"'rer\", \"today\\n \", \"today\\n \\n\", \"today\\n  \\n\", \"hello world\", \"hello  world\", \"hello   world\", \" \\x850\", \"\", \"👍\", \" .\",\n",
"]"
]
},
{
"cell_type": "code",
"source": [
"from transformers import GPT2TokenizerFast, logging\n",
"\n",
"# Hide warning messages\n",
"logging.set_verbosity_error()\n",
"\n",
"output = 'models'\n",
"for model_name in MODEL_INFO:\n",
"    print('Testing', model_name)\n",
"    og_tokenizer = tiktoken.encoding_for_model(model_name)\n",
"    hf_tokenizer = GPT2TokenizerFast.from_pretrained(os.path.join(output, model_name))\n",
"\n",
"    for test in TESTS:\n",
"        # Test encoding\n",
"        og_tokens = og_tokenizer.encode(test, allowed_special={'<|endoftext|>'})\n",
"        hf_tokens = hf_tokenizer.encode(test)\n",
"        assert og_tokens == hf_tokens, f'ENCODE FAIL: \"{test}\". {og_tokens} != {hf_tokens}'\n",
"\n",
"        # Test decoding\n",
"        og_decoded = og_tokenizer.decode(og_tokens)\n",
"        hf_decoded = hf_tokenizer.decode(hf_tokens)\n",
"        assert og_decoded == hf_decoded, f'DECODE FAIL: \"{og_tokens}\". {og_decoded} != {hf_decoded}'\n"
],
"metadata": {
"id": "ELyGSJM0-yA4"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
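
Once converted, each output folder loads like any local Hugging Face tokenizer. A minimal usage sketch (assuming the notebook above has been run, so that `models/gpt-4` exists):

```python
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained('models/gpt-4')

ids = tokenizer.encode('hello world')
print(ids)                    # the same ids tiktoken's cl100k_base produces
print(tokenizer.decode(ids))  # 'hello world'
```
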
@gautierdag commented Sep 5, 2023

I think you've got a bug somewhere (probably in the regex):

import tiktoken
from transformers import GPT2TokenizerFast

tst = '\n\n\n\n\ns1232'

# GPT4 test fail
tokenizer_tik = tiktoken.encoding_for_model("gpt-4")
tokenizer = GPT2TokenizerFast.from_pretrained('Xenova/gpt-4')

print([(tokenizer.decode([i]), i) for i in tokenizer.encode(tst)])
# ('\n\n\n\n', 1038), ('\n', 198), ('s', 82), ('12', 717), ('32', 843)
print([(tokenizer_tik.decode([i]), i) for i in tokenizer_tik.encode(tst)])
# ('\n\n\n\n\n', 14963), ('s', 82), ('123', 4513), ('2', 17)

# GPT2 test - success
tokenizer_tik = tiktoken.encoding_for_model("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained('Xenova/gpt2')

print([(tokenizer.decode([i]), i) for i in tokenizer.encode(tst)])
# [('\n\n', 628), ('\n\n', 628), ('\n', 198), ('s', 82), ('12', 1065), ('32', 2624)]
print([(tokenizer_tik.decode([i]), i) for i in tokenizer_tik.encode(tst)])
# [('\n\n', 628), ('\n\n', 628), ('\n', 198), ('s', 82), ('12', 1065), ('32', 2624)]

@xenova (Author) commented Sep 5, 2023

Is this a problem with the encoding or decoding? Can you split up your checks, please?

Edit: looks like an issue with encoding. Could you still split the tests and send the token IDs? Thanks! Also, could you check with GPT-2 to see if it is still an issue there?

@gautierdag

Edited the test and added digits as well. It's the regex splitting, I think, since GPT-4 / cl100k_base changes the regex from GPT-3.

@gautierdag commented Sep 5, 2023

To fix it, you should change the pre_tokenizer for cl100k_base to:

"pre_tokenizer": {
      "type": "Sequence",
      "pretokenizers": [
        {
          "type": "Split",
          "pattern": {
            "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
          },
          "behavior": "Removed",
          "invert": True
        },
        {
          "type": "ByteLevel",
          "add_prefix_space": False,
          "trim_offsets": True,
          "use_regex": False
        }
      ]
}

When I do that, I pass the simple test above.

Also, I don't think "post_processor" is needed; you can keep it null/None.

Edit: I've just run extensive tests with the above "pre_tokenizer" and I can confirm it is now equivalent to gpt-4.

@xenova (Author) commented Sep 5, 2023

Amazing! Thanks so much @gautierdag! I'll update the script accordingly. If possible, could you share which tests you are running? I'd like to update the regex so that it's compatible with both Python and JavaScript, but the ?i: seems to break in JavaScript.

(see here for my JS demo)

@gautierdag commented Sep 6, 2023

I can't share internal tests unfortunately; I just ran both tokenizers on a large dataset of various text types and confirmed they were the same.

Yeah, the ?i: regex requires special handling: in Python you'd need to use the regex library (not re) to reproduce its functionality. Here the Python bindings use the underlying Rust regex crate, which seems to work (though I think technically the fancy-regex crate should be used instead). Not sure about JS, sorry 😞!
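
For example, a minimal sketch of the Python side:

```python
# Both the scoped flag group (?i:...) and the \p{...} classes work in the
# third-party `regex` package; built-in `re` has no \p{...} support at all.
import regex

pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
print(regex.findall(pattern, "I'M sure you'd AGREE"))
# ['I', "'M", ' sure', ' you', "'d", ' AGREE']
```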

@xenova (Author) commented Sep 6, 2023

I can't share internal tests unfortunately; I just ran both tokenizers on a large dataset of various text types and confirmed they were the same.

No worries!

Yeah, the ?i: regex requires special handling: in Python you'd need to use the regex library (not re) to reproduce its functionality.

Right, I noticed that while playing around with it a bit more yesterday. I suppose the entire regex can be set to case-insensitive mode, no? Do you notice any difference in your tests if ?i: is removed but the entire regex is set to case-insensitive (as opposed to just that first group)?

@gautierdag

Mmmh, I changed the regex to:
"(?i)'s|'t|'re|'ve|'m|'ll|'d|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"

which I think should set the whole expression to case-insensitive. However, this breaks on an odd Unicode character (https://unicode-explorer.com/c/0345). I sadly don't have an explanation for it 🤷

character = 'ͅ' # U+0345 fails

t = hfgpt4.encode(character)
print(t) # [137, 227] - what tiktoken also returns

o = hfgpt4_case_insensitive.encode(character)
print(o) # []

It is equivalent for everything else I tested.

@binxuan commented Sep 20, 2023

Hey, thanks for sharing the solution and discussion! Have we reached a conclusion on which regex to use to fully replicate tiktoken in Hugging Face? Is this pre_tokenizer setting working? Does removing post_processor yield different results?

```python
"pre_tokenizer": {
      "type": "Sequence",
      "pretokenizers": [
        {
          "type": "Split",
          "pattern": {
            "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
          },
          "behavior": "Removed",
          "invert": True
        },
        {
          "type": "ByteLevel",
          "add_prefix_space": False,
          "trim_offsets": True,
          "use_regex": False
        }
      ]
}
```

@gautierdag

Hey, thanks for sharing the solution and discussion! Have we reached a conclusion on which regex to use to fully replicate tiktoken in Hugging Face? Is this pre_tokenizer setting working? Does removing post_processor yield different results?


Try it :)

Removing the post_processor shouldn't do anything, since the decoder already handles byte-level decoding. And as long as you adapted the regex to the GPT-4 version, it should work in Python. I can't speak for JS.
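
A minimal sketch of that point (assuming a recent `tokenizers` release, where decoder objects expose a `decode` method):

```python
# The ByteLevel decoder alone maps byte-level symbols (e.g. 'Ġ' for a
# leading space) back to raw text, so no post_processor is needed.
from tokenizers import decoders

decoder = decoders.ByteLevel()
print(decoder.decode(['Ġhello', 'Ġworld']))  # ' hello world'
```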

@binxuan commented Sep 22, 2023

Yeah, I tried it on 10M cases and only got 2 unmatched token sequences, which should be fine for common use cases. By the way, I noticed that decoding sometimes does not yield the same results between the converted HF tokenizer and tiktoken. For example, I got a different text string from this token sequence: [12906, 224, 61196, 5619, 248, 15272, 113, 55884, 245, 5619, 111, 73414, 13, 15272, 250, 31584, 107, 24810, 15272, 113, 31584, 107, 65804, 31584, 97, 43411, 105, 5619, 99, 31584, 99, 92911, 15272, 228, 5619, 250, 84736, 86133, 80338, 31584, 107, 55884, 243, 32511, 248, 31584, 107, 24810, 92317, 61196, 32511, 97, 15272, 246, 12906, 225, 5619, 96, 24810, 11, 15272, 248, 44747, 5619, 94, 1174, 15272, 108, 32511, 245, 11, 15272, 99, 31584, 113, 55884, 115, 11, 15272, 107, 24810, 15272, 255, 32511, 113, 61196, 32511, 224, 5619, 244, 35470, 45279, 44747, 5619, 250, 48909, 32511, 117, 44747, 15272, 101, 32511, 117, 44747, 11, 15272, 107, 24810, 84736, 86133, 32511, 108, 31584, 114, 31584, 113, 5619, 255, 12906, 224, 88344, 44747, 5619, 113, 45279, 15272, 97, 31584, 107, 24810, 15272, 113, 31584, 107, 65804, 31584, 97, 44747, 5619, 248, 44747, 15272, 105, 32511, 107, 65804, 55675, 15272, 228, 5619, 96, 39951, 92317, 73753, 92911, 32511, 101, 35470, 85410, 35470, 84736, 73753, 79468, 31584, 97, 65804, 15272, 110, 43411, 117, 5619, 96, 31584, 107, 32511, 97, 85410, 24810, 84736, 73753, 5619, 95, 32511, 243, 32511, 108, 15272, 246, 31584, 107, 32511, 113, 24810, 11, 15272, 97, 31584, 107, 32511, 248, 73414, 15272, 228, 5619, 107, 73753, 5619, 115, 31584, 107, 15272, 110, 55675, 65804, 32511, 224, 79468, 88344, 55675, 45279, 92317, 32511, 224, 5619, 94, 5619, 96, 31584, 107, 32511, 248, 24810, 84736, 86133, 5619, 107, 80338, 31584, 101, 48909, 45279, 32511, 113, 24810, 11, 85410, 55884, 248, 15272, 113, 43411, 114, 55884, 115, 15272, 228, 95048, 35470, 13, 15272, 228, 5619, 96, 39951, 15272, 241, 79468, 32511, 106, 32511, 248, 24810, 15272, 114, 55884, 113, 5619, 253, 15272, 251, 32511, 110, 31584, 107, 32511, 101, 73414, 80338, 45279, 15272, 227, 92911, 31584, 103, 32511, 113, 5619, 100, 44747, 80338, 5619, 248, 85410, 35470, 84736, 73753, 79468, 31584, 97, 65804]

@binxuan commented Sep 22, 2023

Nvm, I found it is caused by setting clean_up_tokenization_spaces=True.
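
For anyone hitting the same thing, a minimal sketch of the workaround (using the Xenova/gpt-4 repo mentioned above):

```python
# clean_up_tokenization_spaces=True rewrites e.g. 'hello .' to 'hello.',
# which tiktoken's decoder never does; disable it when decoding.
from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained('Xenova/gpt-4')
ids = tok.encode('hello .')
print(tok.decode(ids, clean_up_tokenization_spaces=False))  # 'hello .'
```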

@KerfuffleV2

@xenova Thanks for posting this! For the purposes of adapting/incorporating into other projects, what's the license for this code? (Maybe add a license note in the comments at the top?)

@xenova (Author) commented Feb 5, 2024

Just an update on the case-insensitive group modifier (?i:), which breaks certain regex implementations (e.g., JS): I think it's reasonable to just replace the problematic section with a longer (but equivalent) version.

Original: (?i:'s|'t|'re|'ve|'m|'ll|'d)|

JS-friendly version: (?:'([sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD]))
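
A quick spot-check of the equivalence (a sketch; the sample strings are arbitrary):

```python
# Compare the scoped-flag contraction group against the spelled-out,
# JS-friendly version on mixed-case inputs (using the `regex` package,
# since built-in `re` lacks \p{...}).
import regex

TAIL = r"|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
ORIGINAL = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" + TAIL
FRIENDLY = r"(?:'([sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD]))" + TAIL

def split(pattern, text):
    return [m.group(0) for m in regex.finditer(pattern, text)]

for text in ["He's", "HE'S", "we'LL", "I'd", "THEY'RE ok"]:
    assert split(ORIGINAL, text) == split(FRIENDLY, text), text
```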

For the purposes of adapting/incorporating into other projects, what's the license for this code?

Do what you want with it :) In any case, my code is adapted from this comment, with a few modifications.

@xenova (Author) commented Mar 27, 2024

I actually forgot to update the gist with my new conversion script, which takes into account the new split pretokenization regex (thanks @gautierdag for pointing that out!).

It also sets the default clean_up_tokenization_spaces to False (thanks @binxuan for pointing that out).

So, now it's updated 🤗 👍 I've also validated the GPT-4 tokenizer on the entire XNLI dataset (all languages) with 100% compatibility (both encoding and decoding). 🔥 Code to validate:

import tqdm
from datasets import load_dataset
import tiktoken
from transformers import GPT2TokenizerFast

hf_tokenizer = GPT2TokenizerFast.from_pretrained('Xenova/gpt-4')
og_tokenizer = tiktoken.encoding_for_model('gpt-4')

dataset = load_dataset('xnli', 'all_languages')

for item in tqdm.tqdm(dataset['train']):
    for string in item['premise'].values():
        encoded1 = og_tokenizer.encode(string)
        encoded2 = hf_tokenizer.encode(string)

        assert encoded1 == encoded2, f'encoding "{string}" is incorrect. "{encoded1}" != "{encoded2}"'

        decoded1 = og_tokenizer.decode(encoded1)
        decoded2 = hf_tokenizer.decode(encoded2, skip_special_tokens=True)

        assert decoded1 == decoded2, f'decoding "{string}" is incorrect. "{decoded1}" != "{decoded2}"'

@david-waterworth

Shouldn't 'tokenizer_class' be 'GPT2Tokenizer' in all cases? That's the concrete Hugging Face class that gets instantiated; i.e., with that setting you can use

 hf_tokenizer = AutoTokenizer.from_pretrained('Xenova/gpt-4')

rather than GPT2TokenizerFast (which otherwise generates a warning).
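
A sketch of what that change enables (assuming a converted folder `models/gpt-4` whose `tokenizer_config.json` contains `"tokenizer_class": "GPT2Tokenizer"`):

```python
# With a concrete class name in tokenizer_config.json, the generic entry
# point resolves the tokenizer (picking the fast variant automatically).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('models/gpt-4')
print(type(tok).__name__)  # GPT2TokenizerFast
```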
