@gautierdag
Created May 23, 2022 20:42
Pretrain a transformer with the MLM (masked language modeling) objective from a Pandas dataframe
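The script below expects a Parquet file, train.parquet, with a single text column named "source". It tokenizes each document, concatenates the token ids (separated by the tokenizer's separator token), splits them into fixed-length chunks, and pretrains a BERT-style model with masked language modeling via the Hugging Face Trainer. As a minimal sketch of the expected input (the example texts here are purely illustrative), such a file could be produced like this:

    import pandas as pd

    pd.DataFrame(
        {"source": ["First raw document text ...", "Second raw document text ..."]}
    ).to_parquet("train.parquet")
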
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

def get_tokenize_and_chunk_fn(tokenizer, context_length=128):
    def tokenize_and_chunk(texts):
        # tokenize each text in the batch and concatenate all token ids,
        # separating documents with the tokenizer's separator token
        all_input_ids = []
        for input_ids in tokenizer(texts["source"], return_token_type_ids=False)[
            "input_ids"
        ]:
            all_input_ids.extend(input_ids)
            all_input_ids.append(tokenizer.sep_token_id)
        # split the concatenated ids into fixed-size chunks of context_length
        chunks = []
        for idx in range(0, len(all_input_ids), context_length):
            chunks.append(all_input_ids[idx : idx + context_length])
        return {"input_ids": chunks}

    return tokenize_and_chunk

if __name__ == "__main__":
    context_length = 128  # length of chunks
    bert_model = "prajjwal1/bert-mini"  # adjust for whichever model to use

    # df with one column ("source") which contains the text to use
    df = pd.read_parquet("train.parquet")
    dataset = Dataset.from_pandas(df)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)

    # apply tokenization and chunk the tokenized text into equal lengths
    tokenized_dataset = dataset.map(
        get_tokenize_and_chunk_fn(tokenizer, context_length=context_length),
        batched=True,
        remove_columns=["source"],
        num_proc=4,
    )

    model = AutoModelForMaskedLM.from_pretrained(bert_model)

    # pad with the separator token (could also use eos) so the collator can
    # batch the final chunk, which may be shorter than context_length
    tokenizer.pad_token = tokenizer.sep_token
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )

    training_args = TrainingArguments(
        output_dir=f"./models/pretrain/{bert_model}",
        num_train_epochs=10,
        per_device_train_batch_size=64,
        save_steps=10000,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_dataset,
    )
    trainer.train()
    trainer.save_model(f"./models/pretrain/{bert_model}")
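
Once pretraining finishes, the checkpoint saved under ./models/pretrain/{bert_model} can be loaded like any other Hugging Face model. A minimal sketch, assuming a downstream binary-classification fine-tune (the task and num_labels are illustrative, not part of the original gist); note that trainer.save_model does not save the tokenizer here, so it is reloaded from the original model name:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # directory written by trainer.save_model above
    pretrained_dir = "./models/pretrain/prajjwal1/bert-mini"
    tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
    # a new classification head is initialized on top of the pretrained encoder
    model = AutoModelForSequenceClassification.from_pretrained(pretrained_dir, num_labels=2)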