import torch
from torch import nn


class MultiSampleDropout(nn.Module):
    """Applies several dropout masks to the same hidden states and averages the results."""

    def __init__(self, dropout_probs, problem_type, num_labels) -> None:
        super().__init__()
        # ModuleList keeps the dropout layers registered on the module
        self.dropouts = nn.ModuleList([nn.Dropout(p=p) for p in dropout_probs])
        self.problem_type = problem_type
        self.num_labels = num_labels

    def forward(self, hidden_states, linear, labels, loss_fn):
        # One set of logits per dropout sample, all sharing the same linear head
        logits = [linear(d(hidden_states)) for d in self.dropouts]
        if self.problem_type == "regression":
            logits = [l.view(-1) for l in logits]
            labels = labels.view(-1)
        elif self.problem_type == "single_label_classification":
            logits = [l.view(-1, self.num_labels) for l in logits]
            labels = labels.view(-1)
        losses = [loss_fn(log, labels) for log in logits]
        # Average logits and losses across the dropout samples
        logits = torch.mean(torch.stack(logits, dim=0), dim=0)
        loss = torch.mean(torch.stack(losses, dim=0), dim=0)
        return (loss, logits)
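
# Hedged sanity check (not part of the original gist): the toy linear head,
# random inputs, and label tensor below are illustrative assumptions only.
# It shows that the module returns a loss averaged over the dropout samples
# and logits averaged over the same samples.
_head = nn.Linear(8, 3)
_msd = MultiSampleDropout(
    dropout_probs=[0.1, 0.3, 0.5],
    problem_type="single_label_classification",
    num_labels=3,
)
_loss, _logits = _msd(
    torch.randn(4, 8), _head, torch.tensor([0, 1, 2, 0]), nn.CrossEntropyLoss()
)
print(_loss.item(), _logits.shape)  # scalar loss, averaged logits of shape (4, 3)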
from transformers import AutoModel, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class MultiSampleDropoutModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Infer the problem type from the number of labels if it was not set
        if self.config.problem_type is None:
            if self.config.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.config.num_labels > 1:
                self.config.problem_type = "single_label_classification"
            # no checks for multi-label classification

        if self.config.problem_type == "regression":
            self.loss_fct = nn.MSELoss()
        elif self.config.problem_type == "single_label_classification":
            self.loss_fct = nn.CrossEntropyLoss()
        elif self.config.problem_type == "multi_label_classification":
            self.loss_fct = nn.BCEWithLogitsLoss()

        self.backbone = AutoModel.from_config(config)
        self.multisample_dropout = MultiSampleDropout(
            config.multisample_dropout, self.config.problem_type, self.config.num_labels
        )
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self._init_weights(self.classifier)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        token_type_ids=None,
        **kwargs
    ):
        # Only pass token_type_ids to backbones that actually received them
        token_type_ids = (
            {"token_type_ids": token_type_ids} if token_type_ids is not None else {}
        )
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **token_type_ids,
            **kwargs
        )
        # Use the embedding of the first ([CLS]) token as the pooled representation
        cls_embedding = outputs[0][:, 0, :]

        loss = None
        if labels is not None:
            loss, logits = self.multisample_dropout(
                cls_embedding, self.classifier, labels, self.loss_fct
            )
        else:
            logits = self.classifier(cls_embedding)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _init_weights(self, module):
        """Initialize the weights."""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Sequential):
            for m in module.modules():
                self._init_weights(m)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
def get_pretrained(config, model_path):
    model = MultiSampleDropoutModel(config)
    if model_path.endswith("pytorch_model.bin"):
        # Load a full fine-tuned checkpoint (backbone + classifier)
        model.load_state_dict(torch.load(model_path))
    else:
        # Otherwise load only the pretrained backbone weights
        model.backbone = AutoModel.from_pretrained(model_path, config=config)
    return model


model_path = "microsoft/deberta-v3-xsmall"
cfg = AutoConfig.from_pretrained(model_path)
cfg.update({"num_labels": 2, "multisample_dropout": [0.5] * 8})
model = get_pretrained(cfg, model_path=model_path)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokens = tokenizer(
    ["This is an example", "this is another"], return_tensors="pt", padding=True
)
labels = torch.tensor([[1], [0]])

# training with labels
output = model(**tokens, labels=labels)
print(output)

# inference with no labels
output = model(**tokens)
print(output)
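
# Hedged follow-up (not part of the original gist): convert the inference
# logits to per-class probabilities with a softmax over the label dimension.
probs = torch.softmax(output.logits, dim=-1)
print(probs)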