from transformers import BertTokenizer, BertPreTrainedModel, BertForSequenceClassification, AdamW
import torch
from torch.nn.functional import cross_entropy

class MonoBERT(BertPreTrainedModel):
    def __init__(self, config):
        # One relevance logit per query-document input.
        config.num_labels = 1
        super(MonoBERT, self).__init__(config)
        self.bert = BertForSequenceClassification(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        logits = outputs[0]  # shape (batch_size, 1)
        return logits

model = MonoBERT.from_pretrained("bert-base-uncased")
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)

# Query "aaa" paired with a relevant document "AAA" and a non-relevant document "BBB".
text = "aaa[SEP]AAA"
text2 = "aaa[SEP]BBB"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer.encode_plus(text, return_tensors="pt")
encoded2 = tokenizer.encode_plus(text2, return_tensors="pt")

output = model(**encoded)
output2 = model(**encoded2)
print(output)
print(output2)

# Pairwise loss: stack the two scores as logits over {relevant, non-relevant}
# and mark the first pair (index 0) as the correct one.
labels = torch.zeros(1, dtype=torch.long)
scores = torch.stack((output.squeeze(1), output2.squeeze(1)), dim=1)
print(scores)
loss = cross_entropy(scores, labels)
print(loss)

loss.backward()
optimizer.step()
optimizer.zero_grad()

# Re-score after one optimization step: the relevant pair should now score
# higher relative to the non-relevant one.
output = model(**encoded)
output2 = model(**encoded2)
print(output)
print(output2)
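The script above splices a literal "[SEP]" into the input string; the tokenizer keeps it intact as a special token, but every token then gets the same segment id. A small sketch of an alternative (not in the original gist) is to pass the query and document to encode_plus as a text pair, so token_type_ids distinguish the two segments the way BERT was pretrained:

# Sketch: encode query and document as a text pair instead of one string
# with a literal "[SEP]"; token_type_ids then mark the two segments.
encoded_pair = tokenizer.encode_plus("aaa", "AAA", return_tensors="pt")
print(encoded_pair["token_type_ids"])  # 0s for the query tokens, 1s for the document tokens
score = model(**encoded_pair)
print(score)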
class MonoBERT(BertPreTrainedModel):
    def __init__(self, config):
        super(MonoBERT, self).__init__(config)
        # This variant keeps the tokenizer on the model itself.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertForSequenceClassification(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        logits = outputs[0]
        return logits
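Since this variant carries its own tokenizer, a scoring helper is a natural fit. The function below is not part of the gist, just a hypothetical sketch of how the pieces could be wired together for inference:

# Hypothetical helper (not in the original gist): encode a query-document
# pair as a BERT text pair and return the scalar relevance logit.
def score(model, query, document):
    encoded = model.tokenizer.encode_plus(query, document, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoded)
    return logits.squeeze().item()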