Skip to content

Instantly share code, notes, and snippets.

@kabirahuja2431
Created October 9, 2019 14:01
Show Gist options
  • Save kabirahuja2431/93a146cca3034ba4d0159e97d5111645 to your computer and use it in GitHub Desktop.
Save kabirahuja2431/93a146cca3034ba4d0159e97d5111645 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from transformers import BertModel
class SentimentClassifier(nn.Module):
    """BERT encoder followed by a linear head applied to the [CLS] token.

    The head outputs a single raw logit per sequence (no sigmoid/softmax is
    applied here). Presumably this is trained with a logits-based binary loss
    such as ``nn.BCEWithLogitsLoss`` -- confirm against the training code.
    """

    def __init__(self, freeze_bert=True, *, bert_model_name='bert-base-uncased'):
        """
        Args:
            freeze_bert: if True, every BERT parameter gets requires_grad=False,
                so only the classification head is updated during training.
            bert_model_name: name or path passed to BertModel.from_pretrained.
                Defaults to 'bert-base-uncased', the model previously hard-coded.
        """
        super(SentimentClassifier, self).__init__()
        # Instantiating BERT model object
        self.bert_layer = BertModel.from_pretrained(bert_model_name)
        # Freeze bert layers
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        # Classification layer. Its input width comes from the loaded model's
        # config instead of a hard-coded 768, so larger BERT variants also work.
        self.cls_layer = nn.Linear(self.bert_layer.config.hidden_size, 1)

    def forward(self, seq, attn_masks):
        '''
        Inputs:
            -seq : Tensor of shape [B, T] containing token ids of sequences
            -attn_masks : Tensor of shape [B, T] containing attention masks used
                          to avoid contribution of PAD tokens
        Returns:
            Tensor of shape [B, 1] with one raw (un-normalized) logit per sequence.
        '''
        # Feeding the input to BERT model to obtain contextualized representations
        outputs = self.bert_layer(seq, attention_mask=attn_masks)
        # Position 0 is the last hidden state [B, T, H] in both output formats:
        # the plain tuple returned by older transformers releases and the
        # dict-like ModelOutput returned by newer ones. Tuple-unpacking a
        # ModelOutput yields its keys (strings), not tensors, which is why
        # `cont_reps, _ = self.bert_layer(...)` breaks on newer versions.
        cont_reps = outputs[0]
        # Obtaining the representation of the [CLS] token (first position)
        cls_rep = cont_reps[:, 0]
        # Feeding cls_rep to the classifier layer
        logits = self.cls_layer(cls_rep)
        return logits
@yishairasowsky
Copy link

@kabirahuja2431 I get an error that maybe you can help me with...
image

@gsm007-data
Copy link

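cont_reps = self.bert_layer(seq, attention_mask = attn_masks)['last_hidden_state'] -> Use this in place of the tuple unpacking in forward(). Newer versions of transformers return a dict-like output object instead of a tuple, so you index it by 'last_hidden_state' to get the last hidden state.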

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment