Created
August 4, 2021 02:37
-
-
Save msaroufim/f0af279af5a0d455870cf62c2a247511 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library
import json
import os
import shutil
from pathlib import Path

# Third-party
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (AdamW, AutoConfig, AutoModelForQuestionAnswering,
                          AutoModelForSequenceClassification,
                          AutoModelForTokenClassification, AutoTokenizer,
                          get_scheduler, set_seed)
os.system("rm -r /home/ubuntu/.cache/huggingface") | |
os.system("rm -r Transformer_model") | |
""" This function, save the checkpoint, config file along with tokenizer config and vocab files | |
of a transformer model of your choice. | |
""" | |
print('Transformers version',transformers.__version__) | |
set_seed(1) | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
def transformers_model_dowloader(mode,pretrained_model_name,num_labels,do_lower_case,max_length,torchscript): | |
print("Download model and tokenizer", pretrained_model_name) | |
#loading pre-trained model and tokenizer | |
if mode== "sequence_classification": | |
config = AutoConfig.from_pretrained(pretrained_model_name,num_labels=num_labels,torchscript=torchscript) | |
model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, config=config) | |
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name,do_lower_case=do_lower_case) | |
elif mode== "question_answering": | |
config = AutoConfig.from_pretrained(pretrained_model_name,torchscript=torchscript) | |
model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name,config=config) | |
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name,do_lower_case=do_lower_case) | |
elif mode== "token_classification": | |
config= AutoConfig.from_pretrained(pretrained_model_name,num_labels=num_labels,torchscript=torchscript) | |
model = AutoModelForTokenClassification.from_pretrained(pretrained_model_name, config=config) | |
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name,do_lower_case=do_lower_case) | |
# NOTE : for demonstration purposes, we do not go through the fine-tune processing here. | |
# A Fine_tunining process based on your needs can be added. | |
# An example of Fine_tuned model has been provided in the README. | |
raw_datasets = load_dataset("imdb") | |
def tokenize_function(examples): | |
return tokenizer(examples["text"], padding="max_length", truncation=True) | |
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) | |
## SETUP DATA_SET | |
# small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000)) | |
# small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000)) | |
# full_train_dataset = tokenized_datasets["train"] | |
# full_eval_dataset = tokenized_datasets["test"] | |
# tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) | |
tokenized_datasets = tokenized_datasets.remove_columns(["text"]) | |
tokenized_datasets = tokenized_datasets.rename_column("label", "labels") | |
tokenized_datasets.set_format("torch") | |
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000)) | |
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000)) | |
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8) | |
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8) | |
optimizer = AdamW(model.parameters(), lr=5e-5) | |
NEW_DIR = "./Transformer_model" | |
try: | |
os.mkdir(NEW_DIR) | |
except OSError: | |
print ("Creation of directory %s failed" % NEW_DIR) | |
else: | |
print ("Successfully created directory %s " % NEW_DIR) | |
## Setup Scheduler | |
num_epochs = 3 | |
num_training_steps = num_epochs * len(train_dataloader) | |
lr_scheduler = get_scheduler( | |
"linear", | |
optimizer=optimizer, | |
num_warmup_steps=0, | |
num_training_steps=num_training_steps | |
) | |
model.to(device) | |
# Setup AMP | |
scaler = torch.cuda.amp.GradScaler() | |
progress_bar = tqdm(range(num_training_steps)) | |
model.train() | |
for epoch in range(num_epochs): | |
for batch in train_dataloader: | |
batch = {k: v.to(device) for k, v in batch.items()} | |
with torch.cuda.amp.autocast(): | |
loss, outputs = model(**batch) | |
# loss = outputs.loss | |
# loss.backward() | |
# optimizer.step() | |
scaler.scale(loss).backward() | |
scaler.step(optimizer) | |
scaler.update() | |
# optimizer.step() | |
lr_scheduler.step() | |
optimizer.zero_grad() | |
progress_bar.update(1) | |
print("Save model and tokenizer/ Torchscript model based on the setting from setup_config", pretrained_model_name, 'in directory', NEW_DIR) | |
if save_mode == "pretrained": | |
model.save_pretrained(NEW_DIR) | |
tokenizer.save_pretrained(NEW_DIR) | |
elif save_mode == "torchscript": | |
dummy_input = "This is a dummy input for torch jit trace" | |
inputs = tokenizer.encode_plus(dummy_input,max_length = int(max_length),pad_to_max_length = True, add_special_tokens = True, return_tensors = 'pt') | |
input_ids = inputs["input_ids"].to(device) | |
attention_mask = inputs["attention_mask"].to(device) | |
model.to(device).eval() | |
traced_model = torch.jit.trace(model, (input_ids, attention_mask)) | |
torch.jit.save(traced_model,os.path.join(NEW_DIR, "traced_model.pt")) | |
return | |
if __name__== "__main__": | |
dirname = os.path.dirname(__file__) | |
filename = os.path.join(dirname, 'setup_config.json') | |
f = open(filename) | |
settings = json.load(f) | |
mode = settings["mode"] | |
model_name = settings["model_name"] | |
num_labels = int(settings["num_labels"]) | |
do_lower_case = settings["do_lower_case"] | |
max_length = settings["max_length"] | |
save_mode = settings["save_mode"] | |
if save_mode == "torchscript": | |
torchscript = True | |
else: | |
torchscript = False | |
transformers_model_dowloader(mode,model_name, num_labels,do_lower_case, max_length, torchscript) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment