Auto model for multiple choice
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import PreTrainedModel, AutoConfig, AutoModel

class ModelWithMultipleChoiceHead(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        base_model = AutoModel.from_config(config)
        # Mirror the base model's class attributes so the PreTrainedModel utilities
        # (weight loading, TF conversion) behave as if this were the base class.
        self.base_model_prefix = base_model.base_model_prefix
        self.config_class = base_model.config_class
        self.load_tf_weights = base_model.load_tf_weights
        # Register the base model under its expected attribute name (e.g. `bert`, `roberta`).
        setattr(self, self.base_model_prefix, base_model)
        # TODO: use something more general like the SequenceSummary, also arg can be config.ndim.
        # `_has_pool` is assumed to be set on the base model (whether it exposes a pooled output).
        if not base_model._has_pool:
            self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        dropout_p = config.hidden_dropout_prob if hasattr(config, 'hidden_dropout_prob') else config.summary_last_dropout
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.init_weights()

    def _init_weights(self, module):
        self.base_model._init_weights(module)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        **kwargs
    ):
r""" | |
Args: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): | |
Indices of input sequence tokens in the vocabulary. | |
Indices can be obtained using :class:`transformers.BertTokenizer`. | |
See :func:`transformers.PreTrainedTokenizer.encode` and | |
:func:`transformers.PreTrainedTokenizer.encode_plus` for details. | |
`What are input IDs? <../glossary.html#input-ids>`__ | |
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): | |
Mask to avoid performing attention on padding token indices. | |
Mask values selected in ``[0, 1]``: | |
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. | |
`What are attention masks? <../glossary.html#attention-mask>`__ | |
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): | |
Segment token indices to indicate first and second portions of the inputs. | |
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` | |
corresponds to a `sentence B` token | |
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): | |
Mask to nullify selected heads of the self-attention modules. | |
Mask values selected in ``[0, 1]``: | |
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. | |
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): | |
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |
This is useful if you want more control over how to convert `input_ids` indices into associated vectors | |
than the model's internal embedding lookup matrix. | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | |
Labels for computing the multiple choice classification loss. | |
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension | |
of the input tensors. (see `input_ids` above) | |
output_attentions (:obj:`bool`, `optional`, defaults to `:obj:`None`): | |
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. | |
model_kwargs: | |
Specific arguments relative to your model, see its documentation for the possibilities. | |
Returns: | |
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: | |
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided): | |
Classification loss. | |
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): | |
`num_choices` is the second dimension of the input tensors. (see `input_ids` above). | |
Classification scores (before SoftMax). | |
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): | |
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |
of shape :obj:`(batch_size, sequence_length, hidden_size)`. | |
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape | |
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. | |
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | |
heads. | |
Examples:: | |
from transformers import AutoTokenizer, ModelWithMultipleChoiceHead | |
import torch | |
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |
model = ModelWithMultipleChoiceHead.from_pretrained('bert-base-uncased') | |
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." | |
choice0 = "It is eaten with a fork and a knife." | |
choice1 = "It is eaten while held in the hand." | |
labels = torch.tensor(0) # choice0 is correct (according to Wikipedia ;)) | |
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True) | |
outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1 | |
# the linear classifier still needs to be trained | |
loss, logits = outputs[:2] | |
""" | |
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Flatten the choice dimension into the batch dimension:
        # (batch_size, num_choices, seq_len) -> (batch_size * num_choices, seq_len).
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        for key in ["position_ids", "input_mask", "global_attention_mask"]:
            if key in kwargs and kwargs[key] is not None:
                kwargs[key] = kwargs[key].view(-1, kwargs[key].size(-1))
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.base_model(
            flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            **kwargs
        )

        if self.base_model._has_pool:
            pooled_output = outputs[1]
            hidden_and_attentions = outputs[2:]
        else:
            # No pooler on the base model: pool by taking the first token's hidden state.
            pooled_output = outputs[0][:, 0]
            pooled_output = self.pre_classifier(pooled_output)
            pooled_output = nn.ReLU()(pooled_output)
            hidden_and_attentions = outputs[1:]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # One score per flattened choice; regroup per example: (batch_size * num_choices, 1) -> (batch_size, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + hidden_and_attentions  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)

    # By default, `from_pretrained` resolves the config from the model class; use `AutoConfig` here instead.
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
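For reference, here is a minimal shape check of the flatten/regroup logic used in `forward`, with made-up sizes (2 questions, 3 answer choices, 16 tokens per prompt/choice pair). It only depends on `torch` and does not need the model above.

import torch

# Hypothetical sizes: batch of 2 questions, 3 answer choices, 16 tokens per pair.
input_ids = torch.randint(0, 100, (2, 3, 16))
num_choices = input_ids.shape[1]

# The forward pass stacks the choices along the batch dimension before calling the base model.
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
assert flat_input_ids.shape == (6, 16)

# The classifier emits one score per flattened choice; regrouping gives one row of scores per question.
per_choice_scores = torch.randn(flat_input_ids.size(0), 1)
reshaped_logits = per_choice_scores.view(-1, num_choices)
assert reshaped_logits.shape == (2, 3)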