Torch and BERT
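Fine-tunes BETO (Spanish BERT) with PyTorch to classify court-ruling sentences into three verdict classes (acoge / mixed / rechaza), logging training metrics to Weights & Biases.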
from tqdm import tqdm
import os
import pickle
import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
import wandb

os.environ['WANDB_API_KEY'] = '<your-wandb-api-key>'  # redacted: never commit a live API key
wandb.init(project="torch-rgb-1", entity="sergiolucero")

pd.set_option('display.max_colwidth', 300)  # bare 'max_colwidth' is deprecated in pandas
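# Load the per-ruling sentences and a pickled Spanish stopword list,
# then strip stopwords from the 'head' column.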
df = pd.read_csv('../DATA/ARRIENDOS/sentsDocRGB_body.csv')
stops = pickle.load(open('stopwords.pk', 'rb'))
print(f'LOADED {len(stops)} STOPWORDS')

df['head'] = df['head'].apply(lambda h: ' '.join(w for w in h.split() if w not in stops))
# Optional whitespace normalisation of the full body text:
# df['text'] = df.text.apply(lambda t: t.replace('\t','').replace('\n',' ').replace('\r',' '))
df['text'] = df['head']  # use the heading only, dropping the verdict text
print('Sentences:', len(df))
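# Map the verdict (fallo) to integer labels and hold out a stratified
# 15% validation split so class proportions match across train/val.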
label_dict = {'acoge': 0, 'mixed': 1, 'rechaza': 2}
df['label'] = df.fallo.replace(label_dict)

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df.index.values,
                                                  df.label.values,
                                                  test_size=0.15,
                                                  stratify=df.label.values)
df['data_type'] = ['not_set'] * df.shape[0]
df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'
print(df.groupby(['fallo', 'label', 'data_type']).count())  # sanity-check the split
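# BETO (dccuchile/bert-base-spanish-wwm-cased): Spanish BERT pretrained with whole-word masking.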
BERT_MODEL = 'dccuchile/bert-base-spanish-wwm-cased'
#BERT_MODEL = 'dccuchile/bert-base-spanish-wwm-uncased'
#BERT_MODEL = 'bert-base-multilingual-cased'
MAX_LENGTH = 480  # 384 also worked; BERT's position embeddings cap sequences at 512 tokens
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=False)  # cased model: keep case
encoded_data_train = tokenizer.batch_encode_plus(
    df[df.data_type == 'train'].text.tolist(),
    add_special_tokens=True, return_attention_mask=True,
    padding='max_length', max_length=MAX_LENGTH,  # pad_to_max_length is deprecated
    truncation=True, return_tensors='pt'
)
encoded_data_val = tokenizer.batch_encode_plus(
    df[df.data_type == 'val'].text.tolist(),
    add_special_tokens=True, return_attention_mask=True,
    padding='max_length', max_length=MAX_LENGTH,
    truncation=True, return_tensors='pt'
)
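# Bundle input ids, attention masks and labels so each DataLoader
# batch unpacks in that fixed order.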
input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df[df.data_type == 'train'].label.values)

input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(df[df.data_type == 'val'].label.values)

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)
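# Pretrained encoder plus a freshly initialised 3-way classification head
# (one logit per verdict class).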
model = BertForSequenceClassification.from_pretrained(BERT_MODEL,
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)
batch_size = 32  # check! was 3

dataloader_train = DataLoader(dataset_train,
                              sampler=RandomSampler(dataset_train),  # shuffle for training
                              batch_size=batch_size)
dataloader_validation = DataLoader(dataset_val,
                                   sampler=SequentialSampler(dataset_val),  # fixed order for eval
                                   batch_size=batch_size)
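# AdamW with a low learning rate (1e-5) is the standard BERT fine-tuning recipe;
# the linear schedule decays the LR to zero over all training steps (no warmup here).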
optimizer = AdamW(model.parameters(),
                  lr=1e-5,
                  eps=1e-8)
epochs = 50
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train) * epochs)
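# Evaluation helpers: weighted F1 across all classes, plus raw per-class hit counts.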
from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat == label]
        y_true = labels_flat[labels_flat == label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')
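# Optional seeding for reproducible runs, then device selection; the model must
# live on the same device as the batches.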
import random

# seed_val = 17
# random.seed(seed_val)
# np.random.seed(seed_val)
# torch.manual_seed(seed_val)
# torch.cuda.manual_seed_all(seed_val)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)  # the original never moved the model, which breaks on 'cuda'
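# Run the model over the validation set without gradients; returns the average
# loss plus stacked logits and true labels for the metric helpers above.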
def evaluate(dataloader_val):
    model.eval()
    loss_val_total = 0
    predictions, true_vals = [], []
    for batch in dataloader_val:
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  }
        with torch.no_grad():
            outputs = model(**inputs)
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()
        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    loss_val_avg = loss_val_total / len(dataloader_val)
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
    return loss_val_avg, predictions, true_vals
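# Training loop: gradient clipping at 1.0, LR stepped per batch, and per-epoch
# validation with predictions pickled for later error analysis.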
for epoch in tqdm(range(1, epochs + 1)):
    model.train()
    loss_train_total = 0
    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:
        model.zero_grad()
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  }
        outputs = model(**inputs)
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        # loss is already averaged over the batch; the original divided by
        # len(batch), which is the tuple arity (always 3), not the batch size
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})

    tqdm.write(f'\nEpoch {epoch}')
    loss_train_avg = loss_train_total / len(dataloader_train)
    tqdm.write(f'Training loss: {loss_train_avg}')

    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    pickle.dump({'epoch': epoch, 'preds': predictions, 'trues': true_vals},
                open(f'pred_true_{epoch}.pk', 'wb'))
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')
    wandb.log({"Loss": val_loss, "F1": val_f1, 'Train': loss_train_avg})

print('DONE')
torch.save(model.state_dict(), f'data_volume/finetuned_BERT_{MAX_LENGTH}_epoch_{epoch}.model')
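A minimal inference sketch, not part of the original gist, showing how the saved weights could be reloaded to classify a new sentence; the checkpoint filename and sample text are placeholders.

# Hedged sketch: reload the fine-tuned weights and classify one sentence.
# The checkpoint path and example text below are hypothetical.
inv_label_dict = {v: k for k, v in label_dict.items()}

model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels=len(label_dict))
model.load_state_dict(torch.load(f'data_volume/finetuned_BERT_{MAX_LENGTH}_epoch_{epochs}.model',
                                 map_location='cpu'))
model.eval()

enc = tokenizer('texto de una sentencia ...', padding='max_length', max_length=MAX_LENGTH,
                truncation=True, return_tensors='pt')
with torch.no_grad():
    logits = model(**enc).logits
print('Predicted verdict:', inv_label_dict[int(logits.argmax(dim=1))])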