# converting list of class weights to a tensor
weights = torch.tensor(class_weights, dtype=torch.float)
# push to GPU
weights = weights.to(device)
# define the loss function
cross_entropy = nn.NLLLoss(weight=weights)
# number of training epochs
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# compute the class weights (keyword arguments are required in scikit-learn >= 0.24)
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(train_labels),
                                     y=train_labels)
print("Class Weights:", class_weights)
# optimizer from hugging face transformers
from transformers import AdamW

# define the optimizer
optimizer = AdamW(model.parameters(),
                  lr=1e-5)   # learning rate
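The excerpt never shows the training loop that ties the optimizer, loss, and data together. A minimal sketch of one epoch, assuming a train_dataloader yielding (sent_id, mask, labels) batches (see the DataLoader sketch further down) and a model whose forward takes (sent_id, mask) and returns log-probabilities; the gradient-clipping threshold is an added assumption, not taken from the gist:

def train_one_epoch(model, train_dataloader, optimizer, cross_entropy, device):
    model.train()
    total_loss = 0

    for batch in train_dataloader:
        # move the batch to the same device as the model
        sent_id, mask, labels = [t.to(device) for t in batch]

        # clear gradients accumulated in the previous step
        optimizer.zero_grad()

        # forward pass: log-probabilities, as nn.NLLLoss expects
        preds = model(sent_id, mask)

        # weighted loss, backprop, and a clipped update
        loss = cross_entropy(preds, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_dataloader)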
# pass the pre-trained BERT to our defined architecture
model = BERT_Arch(bert)
# push the model to GPU
model = model.to(device)
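device is used throughout these snippets but never defined in the excerpt; the usual PyTorch idiom is:

import torch

# use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')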
class BERT_Arch(nn.Module):

    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        # dropout layer
        self.dropout = nn.Dropout(0.1)

# freeze all the parameters
for param in bert.parameters():
    param.requires_grad = False
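The BERT_Arch listing above ends at the dropout layer; the rest of the class is not shown in this excerpt. A minimal sketch of how such a wrapper could be completed, assuming a small feed-forward head over BERT's 768-dim pooled [CLS] output and a LogSoftmax output layer (so it pairs with the nn.NLLLoss defined earlier); the hidden size of 512 and the two output classes are assumptions:

class BERT_Arch(nn.Module):

    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        # feed-forward head on top of BERT's pooled output (sizes are assumptions)
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        # log-probabilities, as expected by nn.NLLLoss
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        # pooled representation of the [CLS] token
        cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
        x = self.dropout(self.relu(self.fc1(cls_hs)))
        return self.log_softmax(self.fc2(x))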
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
# define a batch size
batch_size = 32
# wrap tensors
train_data = TensorDataset(train_seq, train_mask, train_y)
# sampler for sampling the data during training
train_sampler = RandomSampler(train_data)
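DataLoader and SequentialSampler are imported but never used in the excerpt; presumably the tensors are wrapped like this (the validation part is an assumption mirroring the training set):

# dataloader for the training set: shuffled sampling each epoch
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# validation set is iterated in order, no shuffling needed
val_data = TensorDataset(val_seq, val_mask, val_y)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)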
## convert lists to tensors
train_seq = torch.tensor(tokens_train['input_ids'])
train_mask = torch.tensor(tokens_train['attention_mask'])
train_y = torch.tensor(train_labels.tolist())
val_seq = torch.tensor(tokens_val['input_ids'])
val_mask = torch.tensor(tokens_val['attention_mask'])
val_y = torch.tensor(val_labels.tolist())
import pandas as pd

# get the length (in words) of all the messages in the train set
seq_len = [len(i.split()) for i in train_text]
pd.Series(seq_len).hist(bins=30)
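The histogram is typically used to pick the padding length for tokenization (max_length = 25 in the next snippet); a quick numerical check, where the 95th percentile is an assumed cut-off, not taken from the gist:

import numpy as np

# messages longer than this will be truncated by the tokenizer
print(np.percentile(seq_len, 95))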
# tokenize and encode sequences in the training set
tokens_train = tokenizer.batch_encode_plus(
    train_text.tolist(),
    max_length=25,
    pad_to_max_length=True,
    truncation=True
)
# tokenize and encode sequences in the validation set
tokens_val = tokenizer.batch_encode_plus(
    val_text.tolist(),      # assumed by analogy with the training call above
    max_length=25,
    pad_to_max_length=True,
    truncation=True
)