@tarekgh
Last active January 25, 2024 22:21
Tiktoken Tokenizer Proposal

This document outlines the proposal for integrating the Tiktoken Tokenizer into ML.NET. ML.NET currently features a tokenizers library for text, catering to tokenization needs for NLP tasks. Incorporating support for Tiktoken would be a valuable addition to the library, enhancing its capabilities to support AI models like GPT-4.

Usage Example

    Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync("gpt-4");

    // Encoding to Ids
    string text = "Hello World";
    IReadOnlyList<int> encoded = tokenizer.EncodeToIds(text);
    Assert.Equal(new List<int>() { 9906, 4435 }, encoded);
    Assert.Equal(text, tokenizer.Decode(encoded)!);

    // Full encoding to tokens, Ids, and offsets
    TokenizerResult result = tokenizer.Encode(text);
    Assert.Equal(new List<int>() { 9906, 4435 }, result.Ids);
    Assert.Equal(new string[] { "Hello", " World" }, result.Tokens);
    Assert.Equal(new List<(int, int)> { (0, 5), (5, 11) }, result.Offsets);
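Under the hood, a Tiktoken-style tokenizer encodes each pre-tokenized piece with rank-based byte-pair merging: it repeatedly merges the adjacent pair of parts with the lowest rank in the vocabulary. The Python sketch below is illustrative only; the tiny `ranks` vocabulary is hypothetical (real models ship on the order of 100k byte-level ranks), and this is not the ML.NET implementation.

```python
# Illustrative sketch of the rank-based BPE merge loop that Tiktoken-style
# tokenizers run on each pre-tokenized piece. The vocabulary is hypothetical.
def bpe_encode(piece: bytes, ranks: dict[bytes, int]) -> list[int]:
    # Start with one part per byte, then repeatedly merge the adjacent pair
    # with the lowest (best) rank until no mergeable pair remains.
    parts = [piece[i:i + 1] for i in range(len(piece))]
    while len(parts) > 1:
        best = None  # (rank, index) of the best adjacent pair
        for i in range(len(parts) - 1):
            rank = ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best is None or rank < best[0]):
                best = (rank, i)
        if best is None:
            break
        i = best[1]
        parts[i:i + 2] = [parts[i] + parts[i + 1]]
    return [ranks[p] for p in parts]

# Hypothetical ranks: every single byte, plus the merges needed for "Hello".
ranks = {bytes([b]): b for b in range(256)}
ranks |= {b"He": 256, b"ll": 257, b"llo": 258, b"Hello": 259}
print(bpe_encode(b"Hello", ranks))  # [259]
```

A cache in front of this loop (see the `cacheSize` parameter below) matters because the merge loop is quadratic in the piece length in the worst case.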

I've developed a prototype that ports the Microsoft Tokenizer library implementation, taking into account how this tokenizer aligns with the design of the ML.NET tokenizers library while preserving performance. Below are benchmark figures from the prototype, with the Microsoft Tokenizer library serving as the baseline for the primary scenario: encoding text into Ids while taking special tokens into account:


BenchmarkDotNet v0.13.12, Windows 11 (10.0.22631.3085/23H2/2023Update/SunValley3)
11th Gen Intel Core i7-11700 2.50GHz, 1 CPU, 16 logical and 8 physical cores
.NET SDK 8.0.200-preview.23624.5
  [Host]     : .NET 8.0.1 (8.0.123.58001), X64 RyuJIT AVX-512F+CD+BW+DQ+VL+VBMI
  DefaultJob : .NET 8.0.1 (8.0.123.58001), X64 RyuJIT AVX-512F+CD+BW+DQ+VL+VBMI


| Method                          | Mean       | Error     | StdDev   | Gen0   | Gen1   | Allocated |
|-------------------------------- |-----------:|----------:|---------:|-------:|-------:|----------:|
| MSLibEncodeWithSpecialTokens    | 5,887.2 ns |  32.68 ns | 27.29 ns | 1.6861 | 0.0229 |   13.8 KB |
| MLEncodeWithSpecialTokens       | 4,680.5 ns |  45.89 ns | 42.93 ns | 1.4343 | 0.0229 |  11.77 KB |
| MSLibEncodeWithOutSpecialTokens | 5,483.8 ns | 104.65 ns | 92.77 ns | 1.6861 | 0.0229 |   13.8 KB |
| MLEncodeWithOutSpecialTokens    | 4,745.9 ns |  42.38 ns | 39.64 ns | 1.4343 | 0.0229 |  11.77 KB |

The MSLib prefix refers to the Microsoft Tokenizer library, and ML refers to the ML.NET Tokenizers library.

API Proposal

namespace Microsoft.ML.Tokenizers
{
    public class Tokenizer
    {
+        /// <summary>
+        /// Encodes the input text into a result holding the tokens, token Ids, and token offset mappings.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="skipSpecialTokens">Indicates whether to skip the special tokens during encoding.</param>
+        /// <returns>The tokenization result, including the tokens, token Ids, and token offset mappings.</returns>
+        public TokenizerResult Encode(string sequence, bool skipSpecialTokens); // overload adding the skipSpecialTokens parameter.

+        /// <summary>
+        /// Encodes the input text into token Ids.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="skipSpecialTokens">Indicates whether to skip the special tokens during encoding.</param>
+        /// <returns>The list of encoded token Ids.</returns>
+        public IReadOnlyList<int> EncodeToIds(string sequence, bool skipSpecialTokens = false);

+        /// <summary>
+        /// Creates a tokenizer for the specified model name.
+        /// </summary>
+        /// <param name="modelName">The name of the model to create the tokenizer for.</param>
+        /// <param name="extraSpecialTokens">Extra special tokens in addition to the built-in ones for the model.</param>
+        /// <param name="normalizer">Optional normalizer to apply to the text before tokenization.</param>
+        /// <returns>The created tokenizer.</returns>
+        public static async Task<Tokenizer> CreateByModelNameAsync(
+                                                string modelName,
+                                                IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
+                                                Normalizer? normalizer = null)
    }

-    public class Split : IEquatable<Split>
+    public readonly struct Split : IEquatable<Split>
     {
-        public Split(string token, (int Index, int End) offset)
+        public Split(string token, (int Index, int End) offset, bool isSpecialToken = false)

+        /// <summary>
+        /// Gets a value indicating whether the current Split is a special token.
+        /// </summary>
+        public bool IsSpecialToken { get; }
    }

    public abstract class PreTokenizer
    {
+        // Primarily focused on minimizing memory allocations and enabling enumeration
+        // of one item at a time, rather than materializing a large list in a collection.
+        // This change applies to all public classes implementing this abstraction.
-        public abstract IReadOnlyList<Split> PreTokenize(string sentence);
+        public abstract IEnumerable<Split> PreTokenize(string sentence, bool skipSpecialTokens = false);
    }

    public sealed class TokenizerResult
    {
-        public TokenizerResult(string originalString, string normalizedString, IReadOnlyList<Split> splits, bool offsetsMappedToOriginalString);
+        public TokenizerResult(string originalString, string normalizedString, IEnumerable<Split> splits, bool offsetsMappedToOriginalString);
    }


    public abstract class Model
    {
+        public virtual IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken); // overload to add isSpecialToken parameter.

+        public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, List<int> accumulatedIds); // To be consumed by Tokenizer.EncodeToIds

+        public virtual int? TokenToId(string token, bool skipSpecialTokens); // overload adding the skipSpecialTokens parameter.
   }


+    public sealed class Tiktoken : Model
+    {
+        public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = DefaultCacheSize);
+        public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = DefaultCacheSize);

+        public IReadOnlyDictionary<string, int>? SpecialTokens { get; }

+        // Implement the Model abstract methods
+    }

+   public sealed class TikTokenPreTokenizer : PreTokenizer
+   {
+       public TikTokenPreTokenizer(string regexPattern, IReadOnlyDictionary<string, int>? specialTokensEncoder);

+       // Implement the PreTokenizer abstract methods
+   }
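The `tikTokenBpeFile` that the `Tiktoken` constructors consume uses tiktoken's vocabulary format: one entry per line, a base64-encoded token followed by its integer rank, separated by a space. A minimal Python sketch of that parsing (the two sample entries are made up for illustration):

```python
import base64

# Sketch of loading a tiktoken-format BPE vocabulary: each non-blank line is
# a base64-encoded token and its integer rank, separated by a space.
def load_tiktoken_bpe(lines: list[str]) -> dict[bytes, int]:
    ranks: dict[bytes, int] = {}
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        token_b64, rank = line.split()
        ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks

# Two made-up entries in the on-disk format: "IQ==" decodes to b"!".
sample = ["IQ== 0", "IiE= 1"]
print(load_tiktoken_bpe(sample))  # {b'!': 0, b'"!': 1}
```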

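To illustrate what a `TikTokenPreTokenizer`-style component does, the sketch below cuts the input around exact special-token matches so the model can map them directly to their reserved Ids, yielding `(token, (Index, End), isSpecialToken)` tuples analogous to the `Split` type above. The `<|end|>` token and its Id are hypothetical; this is an illustrative Python sketch, not the ML.NET implementation.

```python
import re

# Sketch of special-token-aware splitting: the text is cut around exact
# matches of the special tokens, and each split records whether it is a
# special token. Only the keys of special_tokens matter for splitting.
def split_special_tokens(text: str, special_tokens: dict[str, int]):
    if not special_tokens:
        yield (text, (0, len(text)), False)
        return
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    pos = 0
    for match in re.finditer(pattern, text):
        if match.start() > pos:
            yield (text[pos:match.start()], (pos, match.start()), False)
        yield (match.group(), (match.start(), match.end()), True)
        pos = match.end()
    if pos < len(text):
        yield (text[pos:], (pos, len(text)), False)

# Hypothetical special token "<|end|>" with a made-up Id.
splits = list(split_special_tokens("Hi<|end|>Bye", {"<|end|>": 9999}))
# [('Hi', (0, 2), False), ('<|end|>', (2, 9), True), ('Bye', (9, 12), False)]
```

Yielding splits lazily, as here, mirrors the proposal's change from `IReadOnlyList<Split>` to `IEnumerable<Split>` in `PreTokenize`.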