Last active
December 8, 2021 11:08
-
-
Save xunge/51935912849b2b36274d44fa7deef205 to your computer and use it in GitHub Desktop.
Whitespace tokenizer for training Roberta from scratch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import RobertaProcessing
from tokenizers.trainers import WordLevelTrainer

from model.roberta import RobertaTokenizerFast
# RoBERTa-style special tokens: BOS, padding, EOS, unknown, and mask.
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
# Token substituted for out-of-vocabulary words by the WordLevel model.
UNK_TOKENS = "<unk>"
def prepare_tokenizer_trainer():
    """
    Build a word-level tokenizer and a matching trainer.

    Returns:
        tuple: ``(tokenizer, trainer)`` — a whitespace-splitting
        ``WordLevel`` tokenizer with ``<unk>`` as its unknown token, and a
        ``WordLevelTrainer`` that registers the special tokens.
    """
    word_level_model = WordLevel(unk_token=UNK_TOKENS)
    tokenizer = Tokenizer(word_level_model)
    tokenizer.pre_tokenizer = WhitespaceSplit()
    trainer = WordLevelTrainer(special_tokens=SPECIAL_TOKENS)
    return tokenizer, trainer
def train_tokenizer(files, save_dir="./asm_roberta_tokenizer"):
    """
    Train a whitespace word-level tokenizer on the given files and wrap it
    as a ``RobertaTokenizerFast``.

    Args:
        files: List of paths to plain-text training files.
        save_dir: Directory the trained ``tokenizer.json`` is written to
            (created if missing). Defaults to ``./asm_roberta_tokenizer``
            for backward compatibility.

    Returns:
        RobertaTokenizerFast loaded from the saved tokenizer files.
    """
    tokenizer, trainer = prepare_tokenizer_trainer()
    tokenizer.train(files, trainer)  # learn the vocabulary from the corpus
    # Wrap encodings as <s> ... </s> (and </s> ... </s> for pairs) so the
    # output matches RoBERTa's expected input format.
    tokenizer.post_processor = RobertaProcessing(
        sep=("</s>", tokenizer.token_to_id("</s>")),
        cls=("<s>", tokenizer.token_to_id("<s>")),
    )
    # Tokenizer.save() does not create intermediate directories itself,
    # so make sure the target directory exists before writing.
    os.makedirs(save_dir, exist_ok=True)
    tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
    return RobertaTokenizerFast.from_pretrained(save_dir)
if __name__ == '__main__':
    # Train the tokenizer on the x86-64 assembly corpus.
    corpus_files = ['data/x86-64.txt']
    trained_tokenizer = train_tokenizer(corpus_files)

    input_string1 = (
        "push rax mov rcx , qword ptr [ rdi ] mov rdx , qword ptr [ rdi + num ] "
        "mov r8 , qword ptr [ rsi + num ] mov r9d , dword ptr [ rsi ] xor edi , edi xor esi , esi "
    )
    input_string2 = (
        "push r10 add r10 , num push r10 call var add rsp , var mov ecx , eax test eax , eax "
        "mov eax , var cmove eax , ecx add rsp , num"
    )

    # Encode the two snippets as a sentence pair, then round-trip the ids
    # back to text and report the decoded string's length.
    output_encode = trained_tokenizer.encode(input_string1, input_string2)
    tokens_result = trained_tokenizer.decode(output_encode)
    print(tokens_result, "->", len(tokens_result))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment