Fix tokenizer
""" | |
# Source: https://gist.github.com/jneuff/682d47b786329f19291d166957b3274a | |
/// Fix a huggingface tokenizer to which tokens have been added after training. | |
/// | |
/// Adding tokens after training via `add_special_tokens` leads to them being added to the | |
/// `added_tokens` section but not to the `model.vocab` section. This yields warnings like: | |
/// ``` | |
/// [2023-10-17T07:54:05Z WARN tokenizers::tokenizer::serialization] Warning: Token '<|empty_usable_token_space_1023|>' was expected to have ID '129023' but was given ID 'None' | |
/// ``` | |
/// The code in this file ensures that all tokens from `added_tokens` are also placed into | |
/// `model.vocab`. This fixes the warning and does not change the tokenizer's behavior. | |
use std::collections::HashMap;

use serde_json::Value;
use serde::Deserialize;

#[derive(Deserialize)]
struct AddedToken {
    id: usize,
    content: String,
}
fn main() {
    let raw = std::fs::read("./source.json").unwrap();
    let mut tokenizer: Value = serde_json::from_slice(&raw).unwrap();
    let added_tokens: Vec<AddedToken> = serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
    let vocab: HashMap<String, usize> = serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
    for token in added_tokens {
        if !vocab.contains_key(&token.content) {
            // Mirror the missing token into `model.vocab` under its existing id.
            // `insert` returns the previous value, so `.ok_or(()).unwrap_err()` panics
            // if an existing entry was unexpectedly replaced.
            tokenizer["model"]["vocab"].as_object_mut().unwrap().insert(token.content, token.id.into()).ok_or(()).unwrap_err();
        }
    }
    let raw_fixed = serde_json::to_vec_pretty(&tokenizer).unwrap();
    std::fs::write("./fixed.json", raw_fixed).unwrap();
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_have_expected_diff_in_length_before() {
        let raw = std::fs::read("./source.json").unwrap();
        let tokenizer: Value = serde_json::from_slice(&raw).unwrap();
        let added_tokens: Vec<AddedToken> = serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> = serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        let tokens_not_in_vocab: Vec<&AddedToken> = added_tokens.iter().filter(|&t| !vocab.contains_key(&t.content)).collect();
        assert_eq!(tokens_not_in_vocab.len(), 1024);
    }

    #[test]
    fn should_have_expected_diff_in_length_after() {
        let raw = std::fs::read("./fixed.json").unwrap();
        let tokenizer: Value = serde_json::from_slice(&raw).unwrap();
        let added_tokens: Vec<AddedToken> = serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> = serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        let tokens_not_in_vocab: Vec<&AddedToken> = added_tokens.iter().filter(|&t| !vocab.contains_key(&t.content)).collect();
        assert_eq!(tokens_not_in_vocab.len(), 0);
    }
}
"""
# Python implementation of the Rust fix above.
# Get the tokenizer config for the model first, e.g. from
# https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/blob/main/tokenizer.json
# (or use the optional helper sketched below) and save it as ./tokenizer.json.
import json
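
# Optional: fetch tokenizer.json programmatically instead of downloading it by hand.
# This is only a sketch and assumes the `huggingface_hub` package is installed; the
# repo id default is just the example model mentioned above.
def download_tokenizer(repo_id: str = "microsoft/Phi-3-medium-128k-instruct") -> str:
    from huggingface_hub import hf_hub_download
    # Downloads into the local Hugging Face cache and returns the cached path;
    # copy that file to ./tokenizer.json before running the fix.
    return hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
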
class AddedToken:
    def __init__(
        self,
        id: int,
        content: str,
        single_word: bool,
        lstrip: bool,
        rstrip: bool,
        normalized: bool,
        special: bool,
    ):
        self.id = id
        self.content = content
        self.single_word = single_word
        self.lstrip = lstrip
        self.rstrip = rstrip
        self.normalized = normalized
        self.special = special
def load_json(filename):
    with open(filename, 'r') as file:
        return json.load(file)


def save_json(data, filename):
    with open(filename, 'w') as file:
        json.dump(data, file, indent=4)
def fix_tokenizer():
    # Mirror every entry from `added_tokens` that is missing from `model.vocab`
    # into the vocab, under the id it was already assigned.
    tokenizer = load_json("./tokenizer.json")
    added_tokens = [AddedToken(**token) for token in tokenizer["added_tokens"]]
    vocab = tokenizer["model"]["vocab"]
    for token in added_tokens:
        if token.content not in vocab:
            vocab[token.content] = token.id
    save_json(tokenizer, "./tokenizer.fixed.json")
# Some simple sanity checks: each prints the actual count followed by the expected
# count for the example tokenizer above (1024 missing tokens before the fix, 0 after).
def run_test_tokens_not_in_vocab_before():
    tokenizer = load_json("./tokenizer.json")
    added_tokens = [AddedToken(**token) for token in tokenizer["added_tokens"]]
    vocab = tokenizer["model"]["vocab"]
    tokens_not_in_vocab = [token for token in added_tokens if token.content not in vocab]
    print(len(tokens_not_in_vocab), 1024)


def run_test_tokens_not_in_vocab_after():
    tokenizer = load_json("./tokenizer.fixed.json")
    added_tokens = [AddedToken(**token) for token in tokenizer["added_tokens"]]
    vocab = tokenizer["model"]["vocab"]
    tokens_not_in_vocab = [token for token in added_tokens if token.content not in vocab]
    print(len(tokens_not_in_vocab), 0)
fix_tokenizer()
run_test_tokens_not_in_vocab_before()
run_test_tokens_not_in_vocab_after()
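
# A further, optional check (not part of the original sanity checks): load both files with
# the `tokenizers` library and confirm that the fixed file no longer triggers the
# serialization warning and that encoding output is unchanged. This is only a sketch and
# assumes the `tokenizers` package is installed; the sample text is arbitrary.
def run_test_behavior_unchanged(sample_text="Hello, world!"):
    from tokenizers import Tokenizer
    original = Tokenizer.from_file("./tokenizer.json")      # may log "was given ID 'None'" warnings
    fixed = Tokenizer.from_file("./tokenizer.fixed.json")   # should load without those warnings
    # The fix only mirrors ids the tokenizer already knew about, so encodings should match.
    assert original.encode(sample_text).ids == fixed.encode(sample_text).ids
    print("fixed tokenizer loads and encodes identically")

# Uncomment if the `tokenizers` package is available:
# run_test_behavior_unchanged()
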
# To apply the fix to tokenizer.json in place, first generate a diff file:
#   diff -u tokenizer.json tokenizer.fixed.json > tokenizer.diff
# Then apply the changes from the diff to tokenizer.json:
#   patch tokenizer.json < tokenizer.diff
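# Alternatively, if a reviewable diff is not needed, the original file can simply be
# overwritten with the fixed one using the standard library, e.g.:
#   import shutil
#   shutil.copyfile("./tokenizer.fixed.json", "./tokenizer.json")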