@sgugger
Last active June 15, 2020 12:41
Auto model for multiple choice
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import PreTrainedModel, AutoConfig, AutoModel


class ModelWithMultipleChoiceHead(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Instantiate the base transformer from the config and expose it under its own prefix
        # so that weight loading/saving works the same way as for the native model classes.
        base_model = AutoModel.from_config(config)
        self.base_model_prefix = base_model.base_model_prefix
        self.config_class = base_model.config_class
        self.load_tf_weights = base_model.load_tf_weights
        setattr(self, self.base_model_prefix, base_model)
        # TODO: use something more general like the SequenceSummary, also arg can be config.ndim.
        if not base_model._has_pool:
            self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        dropout_p = config.hidden_dropout_prob if hasattr(config, 'hidden_dropout_prob') else config.summary_last_dropout
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.init_weights()
    def _init_weights(self, module):
        self.base_model._init_weights(module)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        **kwargs
    ):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.BertTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
            output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
                If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
            model_kwargs:
                Arguments specific to your model; see its documentation for the possibilities.
        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.PretrainedConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
                Classification loss.
            classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
                `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
                Classification scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
        Examples::

            from transformers import AutoTokenizer, ModelWithMultipleChoiceHead
            import torch

            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            model = ModelWithMultipleChoiceHead.from_pretrained('bert-base-uncased')

            prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            choice0 = "It is eaten with a fork and a knife."
            choice1 = "It is eaten while held in the hand."
            labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

            encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
            outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1

            # the linear classifier still needs to be trained
            loss, logits = outputs[:2]
        """
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Flatten the (batch_size, num_choices, ...) inputs to (batch_size * num_choices, ...)
        # so the base model sees one sequence per choice.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        for key in ["position_ids", "input_mask", "global_attention_mask"]:
            if key in kwargs and kwargs[key] is not None:
                kwargs[key] = kwargs[key].view(-1, kwargs[key].size(-1))
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.base_model(
            flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.base_model._has_pool:
            # Models with a pooler (e.g. BERT) return the pooled output as the second element.
            pooled_output = outputs[1]
            hidden_and_attentions = outputs[2:]
        else:
            # Otherwise, pool by taking the first token's hidden state and projecting it.
            pooled_output = outputs[0][:, 0]
            pooled_output = self.pre_classifier(pooled_output)
            pooled_output = nn.ReLU()(pooled_output)
            hidden_and_attentions = outputs[1:]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + hidden_and_attentions  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
    # from_pretrained relies on the model class instead of AutoConfig, adjust that.
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
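For reference, a minimal fine-tuning sketch built on the docstring example above. It is not part of the gist: it assumes a transformers build where base models expose the `_has_pool` attribute this class relies on, uses the same `batch_encode_plus`/`pad_to_max_length` API as the docstring example, and the optimizer choice and learning rate are arbitrary.

# Minimal fine-tuning sketch (not part of the gist). Assumes the transformers
# build this class relies on (base models exposing `_has_pool`) and the
# 2020-era tokenizer API used in the docstring example above.
import torch
from torch.optim import AdamW
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = ModelWithMultipleChoiceHead.from_pretrained('bert-base-uncased')
optimizer = AdamW(model.parameters(), lr=5e-5)  # arbitrary hyperparameters

prompt = "In Italy, pizza served in formal settings is presented unsliced."
choices = ["It is eaten with a fork and a knife.", "It is eaten while held in the hand."]
labels = torch.tensor([0])  # index of the correct choice, shape (batch_size,)

# Pair the prompt with every choice; unsqueeze adds the batch dimension so the
# tensors have shape (batch_size=1, num_choices, sequence_length).
encoding = tokenizer.batch_encode_plus(
    [[prompt, choice] for choice in choices], return_tensors='pt', pad_to_max_length=True
)
inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}

model.train()
loss, logits = model(**inputs, labels=labels)[:2]
loss.backward()
optimizer.step()
optimizer.zero_grad()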