# Gist by @Muhammad4hmed (created June 30, 2021): masked-language-model (MLM)
# pretraining of roberta-base on the CommonLit Readability Prize excerpts.
import warnings

import pandas as pd
from transformers import (AutoModelForMaskedLM, AutoTokenizer,
                          LineByLineTextDataset,
                          DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)

warnings.filterwarnings('ignore')
# Build one plain-text corpus from the train and test excerpts
# (the test excerpts carry no labels, so using them for MLM is fine).
train_data = pd.read_csv('../input/commonlitreadabilityprize/train.csv')
test_data = pd.read_csv('../input/commonlitreadabilityprize/test.csv')
data = pd.concat([train_data, test_data])

# One excerpt per line: replace internal newlines with spaces so that words
# on adjacent lines are not glued together, then join excerpts with '\n'.
data['excerpt'] = data['excerpt'].apply(lambda x: x.replace('\n', ' '))
text = '\n'.join(data.excerpt.tolist())
with open('text.txt', 'w') as f:
    f.write(text)
# Load the base checkpoint and save the tokenizer next to where the
# domain-adapted model will be written later.
model_name = 'roberta-base'
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained('./clrp_roberta_base')
# Each line of text.txt becomes one example, truncated to block_size tokens.
train_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path='text.txt',  # train text file
    block_size=256)

valid_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path='text.txt',  # valid text file (the same corpus is reused here)
    block_size=256)

# Dynamically mask 15% of the tokens in each batch for the MLM objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
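
# Optional sanity check (not part of the original gist): run the collator on two
# examples and look at the masking rate. Unmasked and padding positions get the
# label -100, so the printed fraction should land somewhat below mlm_probability.
batch = data_collator([train_dataset[i] for i in range(2)])
print(batch['input_ids'].shape, (batch['labels'] != -100).float().mean())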
training_args = TrainingArguments(
    output_dir='./clrp_roberta_base_chk',  # checkpoint directory
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy='steps',
    eval_steps=200,
    save_steps=200,  # keep saving aligned with evaluation so load_best_model_at_end works
    save_total_limit=2,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    load_best_model_at_end=True,
    prediction_loss_only=True,
    report_to='none')
# Run MLM pretraining and save the domain-adapted checkpoint.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset)

trainer.train()
trainer.save_model('./clrp_roberta_base')
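
# Follow-up sketch (an assumption, not shown in the original gist): the saved
# checkpoint would typically be reloaded for the downstream readability task,
# e.g. as a single-output regression model. num_labels=1 is assumed here, and
# the newly initialised classification head still needs to be fine-tuned.
from transformers import AutoModelForSequenceClassification

clrp_tokenizer = AutoTokenizer.from_pretrained('./clrp_roberta_base')
clrp_model = AutoModelForSequenceClassification.from_pretrained(
    './clrp_roberta_base', num_labels=1)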