def train(net, criterion, opti, train_loader, val_loader, args):
    for ep in range(args.max_eps):
        for it, (seq, attn_masks, labels) in enumerate(train_loader):
            #Clear gradients
            opti.zero_grad()
            #Converting these to cuda tensors
            seq, attn_masks, labels = seq.cuda(args.gpu), attn_masks.cuda(args.gpu), labels.cuda(args.gpu)
            #Obtaining the logits from the model
            logits = net(seq, attn_masks)
            #Computing loss; squeeze the extra dimension so logits match the label shape
            loss = criterion(logits.squeeze(-1), labels.float())
            #Backpropagating the gradients and updating the parameters
            loss.backward()
            opti.step()
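
The capture stops mid-loop and never uses val_loader. As a hedged sketch (not part of the original gist), a validation pass consistent with the pieces above could look like the following; thresholding the raw logits at 0 matches the BCEWithLogitsLoss setup:

import torch

def evaluate(net, criterion, val_loader, args):
    net.eval()
    mean_loss, mean_acc, count = 0.0, 0.0, 0
    with torch.no_grad():
        for seq, attn_masks, labels in val_loader:
            seq, attn_masks, labels = seq.cuda(args.gpu), attn_masks.cuda(args.gpu), labels.cuda(args.gpu)
            logits = net(seq, attn_masks).squeeze(-1)
            mean_loss += criterion(logits, labels.float()).item()
            #A positive logit corresponds to a sigmoid probability above 0.5
            mean_acc += ((logits > 0).long() == labels).float().mean().item()
            count += 1
    net.train()
    return mean_loss / count, mean_acc / count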
import torch.nn as nn
import torch.optim as optim

#The model must exist before its parameters are handed to the optimizer
net = SentimentClassifier(freeze_bert = True)
criterion = nn.BCEWithLogitsLoss()
opti = optim.Adam(net.parameters(), lr = 2e-5)
from torch.utils.data import DataLoader
#Creating instances of training and validation set
train_set = SSTDataset(filename = 'data/SST-2/train.tsv', maxlen = 30)
val_set = SSTDataset(filename = 'data/SST-2/dev.tsv', maxlen = 30)
#Creating instances of training and validation dataloaders
train_loader = DataLoader(train_set, batch_size = 64, num_workers = 5)
val_loader = DataLoader(val_set, batch_size = 64, num_workers = 5)
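
For completeness, the call tying these objects together might be wired up as below; the args namespace with gpu and max_eps fields is a hypothetical stand-in, since the gist never shows how it is constructed:

from types import SimpleNamespace

#Hypothetical stand-in for the argparse namespace the train function expects
args = SimpleNamespace(gpu = 0, max_eps = 2)
net.cuda(args.gpu)
train(net, criterion, opti, train_loader, val_loader, args)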
import torch
import torch.nn as nn
from transformers import BertModel

class SentimentClassifier(nn.Module):
    def __init__(self, freeze_bert = True):
        super(SentimentClassifier, self).__init__()
        #Instantiating BERT model object
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
        #Freezing BERT layers when required, so that only the classification head trains
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        #Classification head (an assumption): maps the 768-dim [CLS] vector to one logit
        self.cls_layer = nn.Linear(768, 1)
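    # NOTE: the gist cuts the class off after __init__. The forward pass below is a
    # hedged sketch consistent with the train loop above (net(seq, attn_masks) -> logits)
    # and the single-logit head assumed in __init__; it is not taken from the original.
    def forward(self, seq, attn_masks):
        #Feeding the input to BERT to obtain contextualized token representations
        cont_reps = self.bert_layer(seq, attention_mask = attn_masks)[0]
        #Using the representation of the [CLS] token (position 0) as the sentence summary
        cls_rep = cont_reps[:, 0]
        #Projecting down to a single logit per example
        return self.cls_layer(cls_rep)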
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd

class SSTDataset(Dataset):
    def __init__(self, filename, maxlen):
        #Store the contents of the file in a pandas dataframe
        self.df = pd.read_csv(filename, delimiter = '\t')
        #Initialize the BERT tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        #Maximum length to which every token sequence is padded or truncated
        self.maxlen = maxlen
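    # NOTE: the remainder of the class is missing from the capture. The methods below
    # are a hedged sketch of what the DataLoader above expects (seq, attn_masks, label);
    # the 'sentence' and 'label' column names are assumptions based on the SST-2 tsv
    # format, not taken from the original.
    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        #Select the sentence and label at the given index from the dataframe
        sentence = self.df.loc[index, 'sentence']
        label = self.df.loc[index, 'label']
        #Tokenize, add the special tokens, and pad/truncate to maxlen
        tokens = ['[CLS]'] + self.tokenizer.tokenize(sentence) + ['[SEP]']
        if len(tokens) < self.maxlen:
            tokens = tokens + ['[PAD]' for _ in range(self.maxlen - len(tokens))]
        else:
            tokens = tokens[:self.maxlen - 1] + ['[SEP]']
        seq = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens))
        #Attention mask: 1 for real tokens, 0 for [PAD] positions
        attn_mask = (seq != 0).long()
        return seq, attn_mask, label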
import torch
from transformers import BertModel, BertTokenizer
#Creating instance of BertModel
bert_model = BertModel.from_pretrained('bert-base-uncased')
#Creating instance of tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#Specifying the max length
T = 12
#Tokenizing a sample sentence, adding the special tokens, and padding to the max length
sentence = 'I really enjoyed this movie a lot.'
tokens = ['[CLS]'] + tokenizer.tokenize(sentence) + ['[SEP]']
padded_tokens = tokens + ['[PAD]' for _ in range(T - len(tokens))]
#Obtaining indices for each token
sent_ids = tokenizer.convert_tokens_to_ids(padded_tokens)
print(sent_ids)
# Out: [101, 1045, 2428, 5632, 2023, 3185, 1037, 2843, 1012, 102, 0, 0]
#Segment ids are all zeros, since we only have a single sequence as input
seg_ids = [0 for _ in range(len(padded_tokens))]
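
A short sketch (an assumption, not part of the gist) of feeding these lists through bert_model, adding only the tensor conversion and a batch dimension:

import torch

#Convert the python lists to tensors with a batch dimension of 1
seq = torch.tensor(sent_ids).unsqueeze(0)
seg = torch.tensor(seg_ids).unsqueeze(0)
attn_mask = (seq != 0).long()  #mask out the [PAD] positions
#Contextualized representations for every token: shape (1, 12, 768)
hidden_states = bert_model(seq, attention_mask = attn_mask, token_type_ids = seg)[0]
print(hidden_states.shape)  # torch.Size([1, 12, 768])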