@tarekgh
Created February 11, 2024 23:18
Tokenizer Interface Changes Proposal

This document captures thoughts and ideas for changes to the Tokenizer interface. The goal is to add or change APIs so that callers have more control over memory allocation and performance. The Tokenizer class is the main class used for tokenization, and models will use it to encode input and decode output. Once the shape of this class is settled, we will know exactly which changes are needed in the other main interfaces, such as Model and PreTokenizer.

    public class Tokenizer
    {
        public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null) { }
        public Model Model { get; }
        public PreTokenizer PreTokenizer { get; set; }
        public Normalizer? Normalizer { get; set; }
        public TokenizerDecoder? Decoder { get; set; }

-       // This can be removed because the other overload already covers it (skipSpecialTokens defaults to false).
-       public TokenizerResult Encode(string sequence) { return default; }

-       // While this operation is computationally expensive, it is necessary when callers require
-       // complete information, including token type and offsets.
-       // We could explore exposing an additional overload that returns IEnumerable<Token> instead of
-       // TokenizerResult. Implementing this would be challenging, but we can consider it if there's demand.
        public TokenizerResult Encode(string sequence, bool skipSpecialTokens = false) { return default; }

+       // One idea is to add an Encode overload that writes into a
+       // Span<(int id, int tokenOffset, int tokenLength)>. The issue is that most tokenizers already create
+       // the token string internally, so this overload may not yield any real optimization or allocation savings.
+       // We may look into this further in the future.

        public IReadOnlyList<int> EncodeToIds(string sequence, bool skipSpecialTokens = false) { return default; }

+       // This writes the IDs into a caller-provided span, giving callers more control over memory allocation.
+       // The sequence is still passed as a string because most pre-tokenizers rely on Regex APIs
+       // that take a string input.
+       // If Regex span support becomes sufficient for pre-tokenization in the future, another overload taking
+       // ReadOnlySpan<char> could be added.
+       public bool EncodeToIds(string sequence, Span<int> ids, out int writtenIds, out int processedCharCount, bool skipSpecialTokens = false) { writtenIds = 0; processedCharCount = 0; return false; }

+       // New method that returns the number of IDs that encoding the sequence would produce
+       // (for example, to size a buffer before calling EncodeToIds).
+       public int GetEncodedIdsCount(string sequence) { return default; }

        public string? Decode(int id, bool skipSpecialTokens = false) { return default; }
        public string? Decode(IEnumerable<int> ids, bool skipSpecialTokens = false) { return default; }

+       // Returns the number of chars written, or a negative number if the destination buffer is too small.
+       // Note: it would be hard to make this method also report how far it processed the input, because some
+       // tokenizers, such as Tiktoken, can map multiple IDs to a single char or word, which makes the boundaries hard to determine.
+       public int Decode(IEnumerable<int> ids, Span<char> destination, bool skipSpecialTokens = false) { return default; }

        public void TrainFromFiles(Trainer? trainer, ReportProgress? progress, params string[] files) { }

-       // This does not really belong here; it is used as a workaround for a single case in the Bert model.
-       // The same result can be achieved another way without exposing this method.
-       public bool IsValidChar(char ch) { return false; }

        public static async Task<Tokenizer> CreateByModelNameAsync(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null) { return default; }
    }
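
The Span<(int id, int tokenOffset, int tokenLength)> idea from the listing can be illustrated with a small, self-contained sketch. Everything in it is hypothetical. EncodeWithOffsets and the one-ID-per-word toy vocabulary are stand-ins, not part of the proposed API. Each entry points back into the input, so a caller can slice the token text on demand instead of receiving a new string per token.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical toy vocabulary; unknown words map to ID 0.
var vocab = new Dictionary<string, int> { ["hello"] = 1, ["world"] = 2 };

// Hypothetical sketch of the tuple-span idea: each entry is (id, offset, length) into the input.
// Returns the number of entries written; stops early if the destination is full.
int EncodeWithOffsets(string sequence, Span<(int Id, int Offset, int Length)> destination)
{
    int written = 0;
    int pos = 0;
    while (pos < sequence.Length && written < destination.Length)
    {
        if (sequence[pos] == ' ') { pos++; continue; } // separators produce no entries
        int end = sequence.IndexOf(' ', pos);
        if (end < 0) end = sequence.Length;
        // This toy still allocates a string for the lookup, mirroring the concern that
        // most tokenizers already create the token string internally.
        int id = vocab.TryGetValue(sequence.Substring(pos, end - pos), out int found) ? found : 0;
        destination[written++] = (id, pos, end - pos);
        pos = end;
    }
    return written;
}

string text = "hello brave world";
var entries = new (int Id, int Offset, int Length)[8];
int count = EncodeWithOffsets(text, entries);
for (int i = 0; i < count; i++)
{
    var (id, offset, length) = entries[i];
    // The token text is a slice of the input; the caller needs no per-token string.
    ReadOnlySpan<char> tokenText = text.AsSpan(offset, length);
    Console.WriteLine($"{id} @ {offset}+{length}: {tokenText.ToString()}");
}
// prints: "1 @ 0+5: hello", "0 @ 6+5: brave", "2 @ 12+5: world"
```

The toy calls Substring for the vocabulary lookup. That is exactly the concern noted in the listing: if the tokenizer creates the token string internally anyway, the savings for the caller are limited.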
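
The proposed EncodeToIds overload that writes into a span implies a resumable calling pattern. The caller fills a fixed buffer, consumes it, then continues from processedCharCount. The following is a minimal sketch of that pattern under stated assumptions:

- the one-ID-per-word tokenizer is a hypothetical stand-in;
- skipSpecialTokens is omitted;
- the bool return (true once the whole sequence has been consumed) is assumed, since the proposal does not define it.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical toy vocabulary standing in for a real Model; unknown words map to ID 0.
var vocab = new Dictionary<string, int> { ["hello"] = 1, ["tokenizer"] = 2, ["world"] = 3 };

// Sketch of the proposed EncodeToIds(string, Span<int>, out int, out int) shape.
// Assumption: returns true only when the whole sequence was consumed. When the span is full it
// stops at a word boundary, so processedCharCount never splits a token.
bool EncodeToIds(string sequence, Span<int> ids, out int writtenIds, out int processedCharCount)
{
    writtenIds = 0;
    processedCharCount = 0;
    int pos = 0;
    while (pos < sequence.Length)
    {
        while (pos < sequence.Length && sequence[pos] == ' ') pos++; // separators produce no IDs
        if (pos == sequence.Length) break;

        int end = sequence.IndexOf(' ', pos);
        if (end < 0) end = sequence.Length;

        if (writtenIds == ids.Length)
        {
            return false; // buffer full; the caller resumes from processedCharCount
        }

        string word = sequence.Substring(pos, end - pos);
        ids[writtenIds++] = vocab.TryGetValue(word, out int id) ? id : 0;
        pos = end;
        processedCharCount = pos;
    }
    processedCharCount = sequence.Length;
    return true;
}

string text = "hello world tokenizer hello unknown";
Span<int> buffer = stackalloc int[2]; // deliberately tiny to force several calls
var all = new List<int>();
int offset = 0;
while (true)
{
    bool done = EncodeToIds(text.Substring(offset), buffer, out int written, out int processed);
    for (int i = 0; i < written; i++) all.Add(buffer[i]);
    offset += processed;
    if (done) break;
}
Console.WriteLine(string.Join(",", all)); // 1,3,2,1,0
```

The proposal keeps string input, so this sketch resumes by taking a Substring of the remaining text. That allocates, which is one cost of the string-based signature.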
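
GetEncodedIdsCount lets a caller learn the size of the result before encoding. For example, a caller could allocate an exact buffer or reject input that exceeds a token budget. A minimal sketch, assuming a hypothetical tokenizer that emits one ID per whitespace-separated word. The MaxTokens budget is also illustrative.

```csharp
using System;

// Hypothetical stand-in for the proposed GetEncodedIdsCount: with a one-ID-per-word toy
// tokenizer, the count is the number of whitespace-separated words. It is computed without
// building an ID list.
int GetEncodedIdsCount(string sequence)
{
    int n = 0;
    bool inWord = false;
    foreach (char c in sequence)
    {
        if (c == ' ') inWord = false;
        else if (!inWord) { inWord = true; n++; }
    }
    return n;
}

const int MaxTokens = 4; // hypothetical context budget
string[] prompts = { "hello world", "hello tokenizer world hello world" };
foreach (string prompt in prompts)
{
    int needed = GetEncodedIdsCount(prompt);
    // Size the destination exactly, or reject the input, before encoding anything.
    Console.WriteLine(needed <= MaxTokens
        ? $"fits: allocate int[{needed}]"
        : $"too long: {needed} > {MaxTokens}");
}
// prints: "fits: allocate int[2]", then "too long: 5 > 4"
```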
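
The proposed Decode overload returns the number of chars written, or a negative number when the destination is too small. One way a caller might handle that is to retry with a larger buffer rented from ArrayPool<char>. This is a sketch under these assumptions:

- the reverse vocabulary is a hypothetical stand-in;
- skipSpecialTokens is omitted;
- -1 is used as the negative value;
- partial output is discarded on failure.

```csharp
using System;
using System.Buffers;
using System.Collections.Generic;

// Hypothetical reverse vocabulary standing in for the model's ID-to-token mapping.
var tokens = new Dictionary<int, string> { [1] = "hello", [2] = " world", [3] = "!" };

// Sketch of the proposed Decode(IEnumerable<int>, Span<char>, bool) contract: returns the number
// of chars written, or -1 (assumed value) when the destination is too small.
int Decode(IEnumerable<int> ids, Span<char> destination)
{
    int written = 0;
    foreach (int id in ids)
    {
        string piece = tokens[id];
        if (written + piece.Length > destination.Length)
        {
            return -1; // partial output is discarded; the caller retries with a larger buffer
        }
        piece.AsSpan().CopyTo(destination.Slice(written));
        written += piece.Length;
    }
    return written;
}

int[] input = { 1, 2, 3 };
int size = 4; // deliberately too small to exercise the retry path
string result = string.Empty;
while (true)
{
    char[] rented = ArrayPool<char>.Shared.Rent(size);
    try
    {
        // Rent may return a larger array, so limit the span to the requested size.
        int charsWritten = Decode(input, rented.AsSpan(0, size));
        if (charsWritten >= 0)
        {
            result = new string(rented, 0, charsWritten);
            break;
        }
    }
    finally
    {
        ArrayPool<char>.Shared.Return(rented);
    }
    size *= 2; // grow and retry
}
Console.WriteLine(result); // hello world!
```

Discarding partial output keeps the contract simple. As the listing notes, reporting how much of the input was decoded is hard for tokenizers such as Tiktoken, so a retry loop is a natural fit for callers.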