Skip to content

Instantly share code, notes, and snippets.

@amallia
Last active November 25, 2021 18:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amallia/48b3bc994b7982a1236cb4d0eec3f84a to your computer and use it in GitHub Desktop.
from transformers import *
import torch
from torch.nn.functional import cross_entropy
from transformers import AdamW
class MonoBERT(BertPreTrainedModel):
    """Pointwise BERT re-ranker (monoBERT).

    Wraps ``BertForSequenceClassification`` configured with a single output
    label, so each forward pass yields one relevance logit for an encoded
    (query, document) pair.

    NOTE(review): loading via ``MonoBERT.from_pretrained("bert-base-uncased")``
    nests the encoder under ``bert.bert.*`` state-dict keys, which do not match
    the checkpoint's ``bert.*`` keys — the pretrained weights are likely NOT
    loaded into ``self.bert``; confirm before relying on pretrained features.
    """

    def __init__(self, config):
        # Force a single logit per pair: monoBERT scores one pair at a time.
        config.num_labels = 1
        super(MonoBERT, self).__init__(config)
        self.bert = BertForSequenceClassification(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return the ``(batch, 1)`` relevance logits for the encoded pair.

        ``attention_mask`` and ``token_type_ids`` now default to ``None``
        (backward-compatible), matching the inner model's own defaults, so the
        model also works with tokenizer outputs that omit them.
        """
        # Pass by keyword: positional pass-through silently breaks if the
        # inner model's argument order ever changes between library versions.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = outputs[0]
        return logits
# --- Demo: score two (query, document) pairs and take one pairwise-loss step.
model = MonoBERT.from_pretrained("bert-base-uncased")
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)

# Same query "aaa" against two documents; label 0 below marks the first
# pair as the relevant one.
text = "aaa[SEP]AAA"
text2 = "aaa[SEP]BBB"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer.encode_plus(text, return_tensors="pt")
encoded2 = tokenizer.encode_plus(text2, return_tensors="pt")

# Call the module, not model.forward(): __call__ runs nn.Module hooks.
output = model(**encoded)
output2 = model(**encoded2)
print(output)
print(output2)

# Pairwise loss: stack the two (1,1) scores into (1,2) logits over
# {relevant, non-relevant}; computed once instead of duplicating the
# torch.stack expression for the print and the loss.
labels = torch.zeros(1, dtype=torch.long)
scores = torch.stack((output.squeeze(1), output2.squeeze(1)), dim=1)
print(scores)
loss = cross_entropy(scores, labels)
print(loss)

loss.backward()
optimizer.step()
optimizer.zero_grad()

# Re-score after the update to observe the scores moving apart.
output = model(**encoded)
output2 = model(**encoded2)
print(output)
print(output2)
class MonoBERT(BertPreTrainedModel):
    """MonoBERT variant that also bundles its tokenizer as an attribute.

    NOTE: this re-definition shadows the earlier ``MonoBERT`` class in this
    file; unlike it, this version does not force ``config.num_labels = 1``.
    """

    def __init__(self, config):
        super(MonoBERT, self).__init__(config)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertForSequenceClassification(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return the classification logits for the encoded input.

        ``attention_mask``/``token_type_ids`` default to ``None``
        (backward-compatible), matching the inner model's own defaults.
        """
        # BUG FIX: the original called self.model(...), but __init__ only
        # assigns self.bert — forward() raised AttributeError on every call.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = outputs[0]
        return logits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment