# Gist by @nbroad1881 — last active June 15, 2022
import torch
from torch import nn


class MultiSampleDropout(nn.Module):
    """Apply several dropout masks to the same hidden states and average the
    resulting logits and losses (multi-sample dropout)."""

    def __init__(self, dropout_probs, problem_type, num_labels) -> None:
        super().__init__()
        # Use ModuleList (rather than a plain Python list) so the dropout layers
        # follow the parent module's train/eval mode.
        self.dropouts = nn.ModuleList([nn.Dropout(p=p) for p in dropout_probs])
        self.problem_type = problem_type
        self.num_labels = num_labels

    def forward(self, hidden_states, linear, labels, loss_fn):
        # One set of logits per dropout mask, all through the same linear head.
        logits = [linear(d(hidden_states)) for d in self.dropouts]
        if self.problem_type == "regression":
            logits = [l.view(-1) for l in logits]
            labels = labels.view(-1)
        elif self.problem_type == "single_label_classification":
            logits = [l.view(-1, self.num_labels) for l in logits]
            labels = labels.view(-1)
        losses = [loss_fn(log, labels) for log in logits]
        # Average logits and losses across the dropout samples.
        logits = torch.mean(torch.stack(logits, dim=0), dim=0)
        loss = torch.mean(torch.stack(losses, dim=0), dim=0)
        return (loss, logits)
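

# --- Illustration (not part of the original gist): a minimal sanity check of
# MultiSampleDropout on random data. The sizes, probabilities, and the names
# prefixed with "_demo" are arbitrary and used only for this sketch.
_demo_msd = MultiSampleDropout([0.1, 0.3, 0.5], "single_label_classification", num_labels=2)
_demo_linear = nn.Linear(8, 2)
_demo_hidden = torch.randn(4, 8)  # (batch, hidden_size)
_demo_labels = torch.tensor([0, 1, 1, 0])
_demo_loss, _demo_logits = _demo_msd(
    _demo_hidden, _demo_linear, _demo_labels, nn.CrossEntropyLoss()
)
print(_demo_loss.item(), _demo_logits.shape)  # scalar loss, logits averaged to (4, 2)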
from transformers import AutoModel, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class MultiSampleDropoutModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        if self.config.problem_type is None:
            if self.config.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.config.num_labels > 1:
                self.config.problem_type = "single_label_classification"
            # no checks for multi-label classification

        if self.config.problem_type == "regression":
            self.loss_fct = nn.MSELoss()
        elif self.config.problem_type == "single_label_classification":
            self.loss_fct = nn.CrossEntropyLoss()
        elif self.config.problem_type == "multi_label_classification":
            self.loss_fct = nn.BCEWithLogitsLoss()

        self.backbone = AutoModel.from_config(config)
        self.multisample_dropout = MultiSampleDropout(
            config.multisample_dropout, self.config.problem_type, self.config.num_labels
        )
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self._init_weights(self.classifier)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        token_type_ids=None,
        **kwargs
    ):
        # Only pass token_type_ids through if the tokenizer produced them.
        token_type_ids = (
            {"token_type_ids": token_type_ids} if token_type_ids is not None else {}
        )
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **token_type_ids,
            **kwargs
        )
        # Use the [CLS] token embedding as the pooled sequence representation.
        cls_embedding = outputs[0][:, 0, :]

        loss = None
        if labels is not None:
            # With labels: multi-sample dropout, averaged logits and loss.
            loss, logits = self.multisample_dropout(
                cls_embedding, self.classifier, labels, self.loss_fct
            )
        else:
            # Without labels: a single pass through the classifier, no dropout.
            logits = self.classifier(cls_embedding)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def _init_weights(self, module):
        """Initialize the weights"""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Sequential):
            for m in module.modules():
                self._init_weights(m)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
def get_pretrained(config, model_path):
    model = MultiSampleDropoutModel(config)
    if model_path.endswith("pytorch_model.bin"):
        # Load a full fine-tuned checkpoint (backbone + classifier head).
        model.load_state_dict(torch.load(model_path))
    else:
        # Otherwise load only pretrained backbone weights from the Hub.
        model.backbone = AutoModel.from_pretrained(model_path, config=config)
    return model
model_path = "microsoft/deberta-v3-xsmall"
cfg = AutoConfig.from_pretrained(model_path)
cfg.update({"num_labels": 2, "multisample_dropout": [0.5] * 8})
model = get_pretrained(cfg, model_path=model_path)
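
# Illustration (assumption, not from the original gist): after fine-tuning and saving
# with torch.save(model.state_dict(), "checkpoint/pytorch_model.bin"), the same helper
# reloads the full model (backbone + classifier) from that file. The path below is a
# placeholder, so the call is left commented out.
# finetuned_model = get_pretrained(cfg, model_path="checkpoint/pytorch_model.bin")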
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokens = tokenizer(["This is an example", "this is another"], return_tensors="pt", padding=True)
labels = torch.tensor([[1], [0]])
# training with labels
output = model(**tokens, labels=labels)
print(output)
# inference with no labels
output = model(**tokens)
print(output)
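
# Illustration (not in the original gist): for deployment-style inference it is common
# to switch to eval mode and disable gradient tracking before the forward pass.
model.eval()
with torch.no_grad():
    output = model(**tokens)
print(output.logits.argmax(dim=-1))  # predicted class index for each input sentence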