Torch and BERT
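# Fine-tunes BETO (Spanish BERT) with PyTorch to predict the ruling label of each document
# ('fallo': acoge / mixed / rechaza) from its stopword-filtered headnote, logging metrics
# to Weights & Biases and saving per-epoch predictions and the final model weights.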
from tqdm import tqdm
import os
import pickle
import pandas as pd
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
import wandb
os.environ['WANDB_API_KEY']='f434112dd0022d671e23cac86d785a2d32f916b5'
wandb.init(project="torch-rgb-1", entity="sergiolucero")
pd.set_option('display.max_colwidth', 300)
df = pd.read_csv('../DATA/ARRIENDOS/sentsDocRGB_body.csv')
with open('stopwords.pk', 'rb') as fp:
    stops = pickle.load(fp)
print(f'LOADED {len(stops)} STOPWORDS')
df['head'] = df['head'].apply(lambda h: ' '.join([w for w in h.split() if w not in stops]))
#print(df.iloc[0].text)
#print('SHOULD I remove stopwords??')
# df['text'] = df.text.apply(lambda t: t.replace('\t','').replace('\n',' ').replace('\r',' '))
df['text'] = df['head'] # removing verdict!
print('nSentences:', len(df))
label_dict = {'acoge': 0, 'mixed': 1, 'rechaza': 2}
df['label'] = df.fallo.replace(label_dict)
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df.index.values,
                                                   df.label.values,
                                                   test_size=0.15,
                                                   stratify=df.label.values)
df['data_type'] = ['not_set'] * df.shape[0]
df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'
print(df.groupby(['fallo', 'label', 'data_type']).count())  # sanity-check the stratified split
BERT_MODEL = 'dccuchile/bert-base-spanish-wwm-cased'
#BERT_MODEL = 'dccuchile/bert-base-spanish-wwm-uncased'
#BERT_MODEL = 'dccuchile/bert-base-multilingual-cased'
MAX_LENGTH = 480 #384 OK! # 512 tops?
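# note: do_lower_case=True lowercases the inputs even though BERT_MODEL above is the cased BETO checkpoint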
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)
encoded_data_train = tokenizer.batch_encode_plus(
    df[df.data_type == 'train'].text.values.tolist(),
    add_special_tokens=True, return_attention_mask=True,
    padding='max_length', max_length=MAX_LENGTH,
    truncation=True, return_tensors='pt'
)
encoded_data_val = tokenizer.batch_encode_plus(
    df[df.data_type == 'val'].text.values.tolist(),
    add_special_tokens=True, return_attention_mask=True,
    padding='max_length', max_length=MAX_LENGTH,
    truncation=True, return_tensors='pt'
)
input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df[df.data_type=='train'].label.values)
input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(df[df.data_type=='val'].label.values)
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)
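# each TensorDataset yields (input_ids, attention_mask, labels) tuples, matching the
# batch unpacking in evaluate() and the training loop below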
model = BertForSequenceClassification.from_pretrained(BERT_MODEL,
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)
batch_size = 32  # check! was 3
dataloader_train = DataLoader(dataset_train,
                              sampler=RandomSampler(dataset_train),
                              batch_size=batch_size)
dataloader_validation = DataLoader(dataset_val,
                                   sampler=SequentialSampler(dataset_val),
                                   batch_size=batch_size)
optimizer = AdamW(model.parameters(),
                  lr=1e-5,
                  eps=1e-8)
epochs = 50
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*epochs)
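# With num_warmup_steps=0, the learning rate decays linearly from 1e-5 to 0 over the
# full len(dataloader_train) * epochs optimizer steps.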
from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat == label]
        y_true = labels_flat[labels_flat == label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy: {len(y_preds[y_preds == label])}/{len(y_true)}\n')
import random
import numpy as np

# seed_val = 17
# random.seed(seed_val)
# np.random.seed(seed_val)
# torch.manual_seed(seed_val)
# torch.cuda.manual_seed_all(seed_val)

device = 'cpu'
# device = 'cuda'
model.to(device)  # keep the model on the same device as the batches
def evaluate(dataloader_val):
    model.eval()
    loss_val_total = 0
    predictions, true_vals = [], []
    for batch in dataloader_val:
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  }
        with torch.no_grad():
            outputs = model(**inputs)
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()
        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    loss_val_avg = loss_val_total / len(dataloader_val)
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
    return loss_val_avg, predictions, true_vals
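# accuracy_per_class is never called in the training loop below; it can be run after
# training on the evaluate() outputs, e.g.:
#   _, preds, trues = evaluate(dataloader_validation)
#   accuracy_per_class(preds, trues)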
for epoch in tqdm(range(1, epochs + 1)):
    model.train()
    loss_train_total = 0
    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:
        model.zero_grad()
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  }
        outputs = model(**inputs)
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        # loss is already averaged over the batch by BertForSequenceClassification
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})
    tqdm.write(f'\nEpoch {epoch}')
    loss_train_avg = loss_train_total / len(dataloader_train)
    tqdm.write(f'Training loss: {loss_train_avg}')
    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    with open(f'pred_true_{epoch}.pk', 'wb') as fp:
        pickle.dump({'epoch': epoch, 'preds': predictions, 'trues': true_vals}, fp)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')
    wandb.log({"Loss": val_loss, "F1": val_f1, 'Train': loss_train_avg})
print('DONE')
torch.save(model.state_dict(), f'data_volume/finetuned_BERT_{MAX_LENGTH}_epoch_{epoch}.model')
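# Optional inference sketch (illustration only, not part of the original run): reload the
# last saved checkpoint and classify one text. The sample sentence is hypothetical.
clf = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels=len(label_dict))
clf.load_state_dict(torch.load(f'data_volume/finetuned_BERT_{MAX_LENGTH}_epoch_{epoch}.model',
                               map_location=device))
clf.to(device)
clf.eval()
sample_text = 'texto de ejemplo de una sentencia'  # hypothetical input
enc = tokenizer(sample_text, padding='max_length', truncation=True,
                max_length=MAX_LENGTH, return_tensors='pt')
enc = {k: v.to(device) for k, v in enc.items()}
with torch.no_grad():
    logits = clf(**enc)[0]  # no labels passed, so the first output is the logits
inv_label_dict = {v: k for k, v in label_dict.items()}
print('Predicted fallo:', inv_label_dict[int(logits.argmax(dim=1))])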