Whitespace tokenizer for training Roberta from scratch
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import RobertaProcessing
from model.roberta import RobertaTokenizerFast  # author's local module; a stock RobertaTokenizerFast also ships with transformers

SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
UNK_TOKEN = "<unk>"

def prepare_tokenizer_trainer():
    """
    Prepares the tokenizer and trainer with unknown & special tokens.
    """
    tokenizer = Tokenizer(WordLevel(unk_token=UNK_TOKEN))
    trainer = WordLevelTrainer(special_tokens=SPECIAL_TOKENS)
    tokenizer.pre_tokenizer = WhitespaceSplit()  # split on whitespace only, no punctuation splitting
    return tokenizer, trainer

def train_tokenizer(files):
    """
    Takes the files and trains the tokenizer.
    """
    tokenizer, trainer = prepare_tokenizer_trainer()
    tokenizer.train(files, trainer)  # train the tokenizer
    # tokenizer.add_special_tokens(SPECIAL_TOKENS)
    # RobertaProcessing applies RoBERTa's special-token template:
    # <s> A </s> for single sequences, <s> A </s> </s> B </s> for pairs
    tokenizer.post_processor = RobertaProcessing(
        sep=("</s>", tokenizer.token_to_id("</s>")),
        cls=("<s>", tokenizer.token_to_id("<s>")),
    )
    os.makedirs("./asm_roberta_tokenizer", exist_ok=True)  # save() fails if the directory is missing
    tokenizer.save("./asm_roberta_tokenizer/tokenizer.json")
    # wrap the trained tokenizer for the transformers-style API
    tokenizer = RobertaTokenizerFast.from_pretrained("./asm_roberta_tokenizer")
    # tokenizer = Tokenizer.from_file("./asm_roberta_tokenizer/tokenizer.json")
    return tokenizer

if __name__ == '__main__':
    files = ['data/x86-64.txt']
    trained_tokenizer = train_tokenizer(files)
    input_string1 = "push rax mov rcx , qword ptr [ rdi ] mov rdx , qword ptr [ rdi + num ] " \
                    "mov r8 , qword ptr [ rsi + num ] mov r9d , dword ptr [ rsi ] xor edi , edi xor esi , esi "
    input_string2 = "push r10 add r10 , num push r10 call var add rsp , var mov ecx , eax test eax , eax " \
                    "mov eax , var cmove eax , ecx add rsp , num"
    # output_encode = trained_tokenizer.encode(input_string1)
    # encode the two strings as a sequence pair; returns a list of token ids
    output_encode = trained_tokenizer.encode(input_string1, input_string2)
    tokens_result = trained_tokenizer.decode(output_encode)
    # tokens_result = output_encode.tokens  # only valid with the raw Tokenizer loaded via from_file
    print(tokens_result, "->", len(tokens_result))  # len() here is the character count of the decoded string
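
For a quick sanity check without the transformers wrapper, the saved tokenizer.json can be reloaded directly with the tokenizers library. A minimal sketch, assuming the script above has already trained and saved the tokenizer (the instruction strings here are illustrative):

from tokenizers import Tokenizer

# Reload the raw whitespace tokenizer written by train_tokenizer()
tok = Tokenizer.from_file("./asm_roberta_tokenizer/tokenizer.json")

# Encoding a pair applies the RobertaProcessing template:
# <s> seq_a </s> </s> seq_b </s>
enc = tok.encode("push rax mov rcx , rdx", "pop rax ret")
print(enc.tokens)       # surface tokens, including the special tokens
print(enc.ids)          # matching vocabulary ids
print(len(enc.tokens))  # token count, unlike len() of a decoded string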