

@jneuff
Created October 17, 2023 11:35
Fix a huggingface tokenizer to which tokens have been added after training
/// Fix a huggingface tokenizer to which tokens have been added after training.
///
/// Adding tokens after training via `add_special_tokens` leads to them being added to the
/// `added_tokens` section but not to the `model.vocab` section. This yields warnings like:
/// ```text
/// [2023-10-17T07:54:05Z WARN tokenizers::tokenizer::serialization] Warning: Token '<|empty_usable_token_space_1023|>' was expected to have ID '129023' but was given ID 'None'
/// ```
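///
/// An affected `tokenizer.json` looks roughly like this (a sketch; the values are
/// made up, but the field names follow the huggingface `tokenizers` serialization
/// format):
/// ```json
/// {
///   "added_tokens": [
///     { "id": 129023, "content": "<|empty_usable_token_space_1023|>", "special": true }
///   ],
///   "model": {
///     "vocab": { "<unk>": 0, "the": 5 }
///   }
/// }
/// ```
/// Note that the added token's `content` does not appear as a key in `model.vocab`.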
/// The code in this file ensures that all tokens from `added_tokens` are also placed into
/// `model.vocab`. This fixes the warning and does not change the tokenizer's behavior.
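///
/// For context, a tokenizer typically ends up in this state via code like the
/// following (a sketch using the Rust `tokenizers` crate; the file names and the
/// token content are illustrative, and the Python bindings behave the same way):
/// ```no_run
/// use tokenizers::{AddedToken, Tokenizer};
///
/// let mut tokenizer = Tokenizer::from_file("./trained.json").unwrap();
/// // The new token is recorded in `added_tokens` but not in `model.vocab`.
/// tokenizer.add_special_tokens(&[AddedToken::from("<|my_new_token|>", true)]);
/// tokenizer.save("./source.json", true).unwrap();
/// ```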
use std::collections::HashMap;
use serde_json::Value;
use serde::Deserialize;

/// One entry of the `added_tokens` section; only the fields we need here.
#[derive(Deserialize)]
struct AddedToken {
    id: usize,
    content: String,
}

fn main() {
    let raw = std::fs::read("./source.json").unwrap();
    let mut tokenizer: Value = serde_json::from_slice(&raw).unwrap();
    let added_tokens: Vec<AddedToken> =
        serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
    let vocab: HashMap<String, usize> =
        serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
    // Copy every added token that is missing from `model.vocab`, keeping its ID.
    for token in added_tokens {
        if !vocab.contains_key(&token.content) {
            let previous = tokenizer["model"]["vocab"]
                .as_object_mut()
                .unwrap()
                .insert(token.content, token.id.into());
            // The key was checked to be absent, so nothing may have been overwritten.
            assert!(previous.is_none());
        }
    }
    let raw_fixed = serde_json::to_vec_pretty(&tokenizer).unwrap();
    std::fs::write("./fixed.json", raw_fixed).unwrap();
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Before the fix, `source.json` has 1024 added tokens missing from `model.vocab`.
    #[test]
    fn should_have_expected_diff_in_length_before() {
        let raw = std::fs::read("./source.json").unwrap();
        let tokenizer: Value = serde_json::from_slice(&raw).unwrap();
        let added_tokens: Vec<AddedToken> =
            serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> =
            serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        let tokens_not_in_vocab: Vec<&AddedToken> = added_tokens
            .iter()
            .filter(|&t| !vocab.contains_key(&t.content))
            .collect();
        assert_eq!(tokens_not_in_vocab.len(), 1024);
    }

    /// After the fix, every added token is present in `model.vocab`.
    #[test]
    fn should_have_expected_diff_in_length_after() {
        let raw = std::fs::read("./fixed.json").unwrap();
        let tokenizer: Value = serde_json::from_slice(&raw).unwrap();
        let added_tokens: Vec<AddedToken> =
            serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> =
            serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        let tokens_not_in_vocab: Vec<&AddedToken> = added_tokens
            .iter()
            .filter(|&t| !vocab.contains_key(&t.content))
            .collect();
        assert_eq!(tokens_not_in_vocab.len(), 0);
    }
}
@EricLBuehler

@jneuff thank you so much for adding this! Refs EricLBuehler/mistral.rs#314.
