This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Wrap the per-class weights in a float tensor, move it to the active
# device, and build a weighted negative-log-likelihood loss with it
# (the model head presumably ends in log_softmax — confirm upstream).
weights = torch.tensor(class_weights, dtype=torch.float).to(device)
cross_entropy = nn.NLLLoss(weight=weights)

# number of training epochs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.utils.class_weight import compute_class_weight

# Compute inverse-frequency ("balanced") class weights so the loss can
# compensate for label imbalance in the training set.
# NOTE: scikit-learn >= 0.24 requires these as keyword arguments; the old
# positional form compute_class_weight('balanced', classes, y) raises a
# TypeError on current releases.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels,
)
print("Class Weights:", class_weights)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# transformers.AdamW was deprecated (and later removed); Hugging Face
# recommends torch.optim.AdamW as the drop-in replacement.
from torch.optim import AdamW

# define the optimizer over all trainable parameters
optimizer = AdamW(model.parameters(), lr=1e-5)  # learning rate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Wrap the pre-trained BERT encoder in our classification architecture
# and move the whole model onto the selected device in one step.
model = BERT_Arch(bert).to(device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BERT_Arch(nn.Module):
    """Classification architecture built around a pre-trained BERT encoder."""

    def __init__(self, bert):
        super().__init__()
        # the pre-trained encoder supplied by the caller
        self.bert = bert
        # regularisation applied between the encoder and the head
        self.dropout = nn.Dropout(0.1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Freeze every pre-trained BERT weight so no gradients flow into the
# encoder — only the newly added layers will be updated during training.
for frozen in bert.parameters():
    frozen.requires_grad = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# mini-batch size used during training
batch_size = 32

# Bundle the encoded inputs, attention masks and labels into a single
# dataset, and sample it in random order each epoch.
train_data = TensorDataset(train_seq, train_mask, train_y)
train_sampler = RandomSampler(train_data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Convert the tokenizer output (lists of ids/masks) and the label
# sequences into PyTorch tensors for the training and validation splits.
train_seq = torch.tensor(tokens_train['input_ids'])
train_mask = torch.tensor(tokens_train['attention_mask'])
train_y = torch.tensor(train_labels.tolist())

val_seq = torch.tensor(tokens_val['input_ids'])
val_mask = torch.tensor(tokens_val['attention_mask'])
val_y = torch.tensor(val_labels.tolist())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot a histogram of whitespace-token counts over the training texts;
# this informs the max_length chosen for the tokenizer below.
seq_len = [len(text.split()) for text in train_text]
pd.Series(seq_len).hist(bins=30)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tokenize and encode the training texts to fixed-length sequences
# (25 tokens, chosen from the length histogram above).
# NOTE: `pad_to_max_length=True` is deprecated in recent transformers
# releases; `padding='max_length'` is the supported equivalent.
tokens_train = tokenizer.batch_encode_plus(
    train_text.tolist(),
    max_length=25,
    padding='max_length',
    truncation=True
)
# tokenize and encode sequences in the validation set | |
tokens_val = tokenizer.batch_encode_plus( |