@NegatioN
Created May 12, 2022 13:26
A TorchScript-able MiniLM Tokenizer
from typing import Dict, List

import torch


class MiniLMTokenizer(torch.nn.Module):
    def __init__(self, vocab: Dict[str, int], vocab_scores: Dict[str, float]):
        '''
        :param vocab: A dictionary mapping from string to index of token
        :param vocab_scores: A dictionary mapping from string to the score of a given token. This is used to decide
                             which tokenization is most probable for our input string.
                             For unigram models this should be available under `~/.cache/torch/sentence_transformers/$YOURMODEL/unigram.json`
                             You might also need to flip the scores if they're negative:
                             ```
                             import json
                             with open(f'~/.cache/torch/sentence_transformers/{org}_{model_name}/unigram.json', 'r') as f:
                                 unigram_scores = json.load(f)
                             ug_scores = {k: -v for (k, v) in unigram_scores['vocab']}
                             ```
        '''
        super(MiniLMTokenizer, self).__init__()
        self.vocab, self.vocab_scores = vocab, vocab_scores
        self.UNK_TOKEN = '[unk]'
        self.word_start_token = '▁'
        self.START_ID, self.END_ID = 0, 2
        # Use a score of -1 to indicate "uninitialized", since Optionals are clunky in TorchScript.
        self.uninitialized_value: float = -1.

    def tokenize(self, inp: str) -> List[str]:
        all_tokens: List[str] = []
        for word in inp.split(" "):
            all_tokens.extend(self.encode_word(word))
        return all_tokens

    def simple_numberize(self, tokens: List[str], max_len: int = 20) -> List[int]:
        # Truncate to max_len (including the two special tokens) and wrap in start/end ids.
        # Tokens missing from the vocab map to -1.
        return [self.START_ID] + [self.vocab.get(x, -1) for x in tokens][:(max_len - 2)] + [self.END_ID]

    # Slightly modified from the Huggingface reference implementation of unigram tokenization in Python:
    # https://huggingface.co/course/chapter6/7
    def encode_word(self, word: str) -> List[str]:
        word = self.word_start_token + word
        # best_segmentations[i] holds the best segmentation of word[:i] found so far, as
        # {"start": start index of its last token, "score": accumulated score of the segmentation}.
        best_segmentations: List[Dict[str, float]] = [{"start": 0., "score": self.uninitialized_value} for _ in range(len(word) + 1)]
        best_segmentations[0]['score'] = 1.
        for start_idx in range(len(word)):
            # This should have been properly filled in by the previous steps of the loop
            best_score_at_start = best_segmentations[start_idx]["score"]
            for end_idx in range(start_idx + 1, len(word) + 1):
                token = word[start_idx:end_idx]
                if token in self.vocab_scores and best_score_at_start != self.uninitialized_value:
                    score = self.vocab_scores[token] + best_score_at_start
                    # If we have found a better segmentation ending at end_idx, we update
                    end_segment = best_segmentations[end_idx]
                    if end_segment["score"] == self.uninitialized_value or end_segment["score"] > score:
                        new_segment: Dict[str, float] = {"start": float(start_idx), "score": score}
                        best_segmentations[end_idx] = new_segment
        segmentation = best_segmentations[-1]
        if segmentation["score"] == self.uninitialized_value:
            # No segmentation of the word was possible with the given vocabulary.
            return [self.UNK_TOKEN]
        # Walk backwards through the best segmentations to recover the tokens, last to first.
        score, start, end = segmentation["score"], int(segmentation["start"]), len(word)
        tokens: List[str] = []
        while start != 0:
            tokens.insert(0, word[start:end])
            start, end = int(best_segmentations[start]["start"]), start
        tokens.insert(0, word[start:end])
        return tokens

    def forward(self, sentence: str, max_len: int = 20) -> List[int]:
        res = self.tokenize(sentence)
        return self.simple_numberize(res, max_len)
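
For reference, a minimal usage sketch. The model name below is a placeholder for whichever unigram-based MiniLM checkpoint you use, and it assumes `unigram.json` (see the docstring above) has been copied into the working directory and that `transformers` is installed:

```
import json
import torch
from transformers import AutoTokenizer

# Placeholder: any sentence-transformers unigram-based MiniLM checkpoint.
model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

# token -> id mapping straight from the Huggingface tokenizer.
vocab = AutoTokenizer.from_pretrained(model_name).get_vocab()

# Unigram scores, flipped to positive as described in the docstring.
# The actual cache path depends on your setup; adjust accordingly.
with open('unigram.json', 'r') as f:
    ug_scores = {k: -v for (k, v) in json.load(f)['vocab']}

tokenizer = torch.jit.script(MiniLMTokenizer(vocab, ug_scores))
print(tokenizer('hello world', max_len=20))  # token ids, e.g. [0, ..., 2]
tokenizer.save('minilm_tokenizer.pt')        # a Python-free artifact, loadable from C++ etc.
```

Once scripted and saved, the tokenizer can be loaded with `torch.jit.load` alongside the scripted model, so the whole pipeline runs without a Python runtime.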