Skip to content

Instantly share code, notes, and snippets.

@nbroad1881
Last active April 11, 2022 21:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nbroad1881/0f7ee2ac87e70fe6ce62d6f4060f4324 to your computer and use it in GitHub Desktop.
Save nbroad1881/0f7ee2ac87e70fe6ce62d6f4060f4324 to your computer and use it in GitHub Desktop.
from torch import nn
from transformers import AutoModel
class Model(nn.Module):
    """Transformer encoder plus a linear head, trained with multi-sample dropout.

    In training mode the encoder's first output (``outputs[0]``) is passed
    through ``self.dropout`` and then through several extra dropout layers of
    different probabilities. The single shared ``self.output`` head projects
    each dropped-out copy. Both the logits and the per-branch losses are
    averaged over the branches.

    In eval mode every ``nn.Dropout`` is the identity, so all branches would be
    identical and only one projection is computed. No code edits are needed to
    switch between training and inference.

    Args:
        config: Hugging Face model config. Reads ``hidden_dropout_prob`` and
            ``hidden_size``. Reads ``num_labels`` only when ``num_outputs`` is
            omitted, and ``name_or_path`` only when ``backbone`` is omitted.
        num_outputs: Output width of the linear head. Defaults to
            ``config.num_labels``.
        loss_fn: Callable invoked as ``loss_fn(logits, labels)``. It is
            required only when ``forward`` receives ``labels``.
        dropout_probs: Dropout probability of each extra branch. Must be
            non-empty. Defaults to the 0.1, 0.2, 0.3, 0.4, 0.5 schedule.
        backbone: An already-constructed encoder module. When omitted, it is
            loaded with ``AutoModel.from_pretrained(config.name_or_path,
            config=config)``.

    Raises:
        ValueError: If ``dropout_probs`` is empty.
    """

    def __init__(
        self,
        config,
        num_outputs=None,
        loss_fn=None,
        dropout_probs=(0.1, 0.2, 0.3, 0.4, 0.5),
        backbone=None,
    ):
        super().__init__()
        # Materialise once so generators work and emptiness can be checked
        # before any (potentially slow) model loading.
        dropout_probs = tuple(dropout_probs)
        if not dropout_probs:
            raise ValueError("dropout_probs must contain at least one probability")
        if num_outputs is None:
            num_outputs = config.num_labels
        if backbone is None:
            backbone = AutoModel.from_pretrained(config.name_or_path, config=config)

        # Registration order (encoder first) matches the original module layout.
        self.model = backbone
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # nn.Dropout has no parameters or buffers, so holding the branches in a
        # ModuleList leaves state_dict keys identical to the dropout1..5 layout.
        self.dropouts = nn.ModuleList(nn.Dropout(p) for p in dropout_probs)
        self.output = nn.Linear(config.hidden_size, num_outputs)
        self.loss_fn = loss_fn

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode the inputs and return averaged logits and, optionally, a loss.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Targets passed to ``loss_fn``. If ``None``, no loss is
                computed (e.g. at inference).

        Returns:
            A tuple ``(logits, loss)``. ``logits`` is the mean over dropout
            branches. ``loss`` is the mean per-branch loss, or ``None`` when
            ``labels`` is ``None``.

        Raises:
            ValueError: If ``labels`` is given but no ``loss_fn`` was configured.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Presumably last_hidden_state, (batch, seq_len, hidden_size) for an
        # AutoModel encoder -- TODO confirm for the chosen backbone.
        hidden = self.dropout(outputs[0])

        if self.training:
            branch_logits = [self.output(drop(hidden)) for drop in self.dropouts]
        else:
            # nn.Dropout is the identity in eval mode, so every branch would
            # produce the same logits. Project once instead of once per branch.
            branch_logits = [self.output(hidden)]

        # sum() adds left to right starting from 0, matching the original
        # (logits1 + logits2 + ...) / n arithmetic.
        logits = sum(branch_logits) / len(branch_logits)

        loss = None
        if labels is not None:
            if self.loss_fn is None:
                raise ValueError("labels were given but the model has no loss_fn")
            loss = sum(
                self.loss_fn(branch, labels) for branch in branch_logits
            ) / len(branch_logits)
        return logits, loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment