Fix a huggingface tokenizer to which tokens have been added after training
/// Fix a huggingface tokenizer to which tokens have been added after training.
///
/// Adding tokens after training via `add_special_tokens` leads to them being added to the
/// `added_tokens` section but not to the `model.vocab` section. This yields warnings like:
/// ```text
/// [2023-10-17T07:54:05Z WARN tokenizers::tokenizer::serialization] Warning: Token '<|empty_usable_token_space_1023|>' was expected to have ID '129023' but was given ID 'None'
/// ```
/// The code in this file ensures that all tokens from `added_tokens` are also placed into
/// `model.vocab`. This fixes the warning and does not change the tokenizer's behavior.
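///
/// For illustration, a hypothetical excerpt of an affected `tokenizer.json`; the exact set
/// of fields per added token may vary, but `id` and `content` are the ones this fix relies on:
/// ```json
/// {
///   "added_tokens": [
///     { "id": 129023, "content": "<|empty_usable_token_space_1023|>" }
///   ],
///   "model": {
///     "vocab": { "<s>": 0, "</s>": 1 }
///   }
/// }
/// ```
/// After the fix, `model.vocab` also maps `<|empty_usable_token_space_1023|>` to `129023`.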
use std::collections::HashMap;

use serde::Deserialize;
use serde_json::Value;
/// The subset of an `added_tokens` entry we need: the token's id and its string content.
#[derive(Deserialize)]
struct AddedToken {
    id: usize,
    content: String,
}
fn main() {
    let raw = std::fs::read("./source.json").unwrap();
    let mut tokenizer: Value = serde_json::from_slice(&raw).unwrap();
    let added_tokens: Vec<AddedToken> =
        serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
    let vocab: HashMap<String, usize> =
        serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
    // Copy every added token that is missing from the vocabulary, keeping its id.
    for token in added_tokens {
        if !vocab.contains_key(&token.content) {
            let previous = tokenizer["model"]["vocab"]
                .as_object_mut()
                .unwrap()
                .insert(token.content, token.id.into());
            // We only insert tokens that were missing, so there must be no previous value.
            assert!(previous.is_none());
        }
    }
    let raw_fixed = serde_json::to_vec_pretty(&tokenizer).unwrap();
    std::fs::write("./fixed.json", raw_fixed).unwrap();
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_have_expected_diff_in_length_before() {
        let raw = std::fs::read("./source.json").unwrap();
        let tokenizer: Value = serde_json::from_slice(&raw).unwrap();
        let added_tokens: Vec<AddedToken> =
            serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> =
            serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        let tokens_not_in_vocab: Vec<&AddedToken> = added_tokens
            .iter()
            .filter(|&t| !vocab.contains_key(&t.content))
            .collect();
        assert_eq!(tokens_not_in_vocab.len(), 1024);
    }

    #[test]
    fn should_have_expected_diff_in_length_after() {
        let raw = std::fs::read("./fixed.json").unwrap();
        let tokenizer: Value = serde_json::from_slice(&raw).unwrap();
        let added_tokens: Vec<AddedToken> =
            serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> =
            serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        let tokens_not_in_vocab: Vec<&AddedToken> = added_tokens
            .iter()
            .filter(|&t| !vocab.contains_key(&t.content))
            .collect();
        assert_eq!(tokens_not_in_vocab.len(), 0);
    }
}
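As a quick sanity check, the repaired file can be loaded with the Rust `tokenizers` crate to confirm that deserialization no longer emits the ID warnings. A minimal sketch, assuming `tokenizers` is available as a dependency (not part of the gist itself):

```rust
use tokenizers::Tokenizer;

fn main() {
    // Deserializing the fixed file should no longer log
    // "Token ... was expected to have ID ... but was given ID 'None'".
    let tokenizer = Tokenizer::from_file("./fixed.json").unwrap();
    println!("vocab size: {}", tokenizer.get_vocab_size(true));
}
```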
@jneuff thank you so much for adding this! Refs EricLBuehler/mistral.rs#314.