from transformers import ElectraForMaskedLM, ElectraForPreTraining
from transformers.models.electra.modeling_electra import ElectraForPreTrainingOutput
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
import torch


class ElectraModel(nn.Module):
    def __init__(self,
                 generator: str = 'google/electra-base-generator',
                 discriminator: str = 'google/electra-base-discriminator',
                 loss_weights=(1, 50)):
        """
        ELECTRA model for pre-training
        :param generator: Name of the generator model. Can be any model with an MLM head
        :param discriminator: Name of the discriminator model. Can be any model with a binary classification head at the end
        :param loss_weights: (generator weight, discriminator weight); the defaults are the values suggested in the original paper
        """
        super(ElectraModel, self).__init__()
        self.generator: ElectraForMaskedLM = ElectraForMaskedLM.from_pretrained(generator)
        self.discriminator: ElectraForPreTraining = ElectraForPreTraining.from_pretrained(discriminator)
        # share the embeddings between the generator and the discriminator
        self.discriminator.electra.embeddings = self.generator.electra.embeddings
        # usual routine to tie the embedding weights to the LM head
        self.generator.generator_lm_head.weight = self.generator.electra.embeddings.word_embeddings.weight
        self.generator_loss_fct = nn.CrossEntropyLoss()
        self.gumbel_dist = torch.distributions.gumbel.Gumbel(0., 1.)
        self.loss_weights = loss_weights

    def forward(self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor, labels: Tensor, mlm_mask: Tensor):
        """
        :param input_ids: Input IDs already modified by MLM masking
        :param attention_mask: Attention mask
        :param token_type_ids: Token type IDs
        :param labels: Original input IDs
        :param mlm_mask: A binary mask marking the masked token positions
        :return: ElectraForPreTrainingOutput with the combined weighted loss and the discriminator's outputs
        """
        # generator (the HF-computed loss in `output.loss` is unused; the MLM loss is recomputed below)
        output = self.generator(input_ids, attention_mask, token_type_ids, labels=labels)
        generator_logits = output.logits[mlm_mask, :]
        # sampling
        with torch.no_grad():
            # sample replacement tokens from the generator distribution
            generator_tokens = self.sample(generator_logits)
            # substitute the sampled token ids at the masked positions
            discriminator_input = input_ids.clone()
            discriminator_input[mlm_mask] = generator_tokens
            # labels: a position counts as replaced only if the sample differs from the original token
            is_replaced = mlm_mask.clone()
            is_replaced[mlm_mask] = (generator_tokens != labels[mlm_mask])
        # discriminator
        output_discriminator = self.discriminator(discriminator_input, attention_mask, token_type_ids, labels=is_replaced)
        # loss: the generator's MLM loss is computed only over the replaced positions
        generator_loss = self.generator_loss_fct(generator_logits[is_replaced[mlm_mask], :], labels[is_replaced])
        loss = output_discriminator.loss*self.loss_weights[1] + generator_loss*self.loss_weights[0]
        return ElectraForPreTrainingOutput(
            loss=loss,
            logits=output_discriminator.logits,
            hidden_states=output_discriminator.hidden_states,
            attentions=output_discriminator.attentions,
        )

    def sample(self, logits, sampling="fp32_gumbel"):
        """Reimplementation of Gumbel softmax, because torch.nn.functional.gumbel_softmax has a bug under fp16
        (https://github.com/pytorch/pytorch/issues/41663).
        This matches what the official ELECTRA code does: standard Gumbel dist.
        = -ln(-ln(standard uniform dist.))
        """
        if sampling == 'fp32_gumbel':
            gumbel = self.gumbel_dist.sample(logits.shape).to(logits.device)
            return (logits.float() + gumbel).argmax(dim=-1)
        elif sampling == 'fp16_gumbel':  # 5.06 ms
            gumbel = self.gumbel_dist.sample(logits.shape).to(logits.device)
            return (logits + gumbel).argmax(dim=-1)
        elif sampling == 'multinomial':  # 2.X ms
            return torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze()
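
For context, here is a minimal usage sketch of the forward pass. Everything below is illustrative and not part of the gist: the tokenizer choice, the example sentence, and the BERT-style 15% random masking are assumptions.

import torch
from transformers import ElectraTokenizerFast

# assumed setup: any ELECTRA-compatible tokenizer works, as long as it matches the checkpoints above
tokenizer = ElectraTokenizerFast.from_pretrained('google/electra-base-generator')
model = ElectraModel()

batch = tokenizer(["The quick brown fox jumps over the lazy dog."], return_tensors='pt')
labels = batch['input_ids'].clone()

# BERT-style masking (assumed): mask ~15% of the non-special tokens
special = torch.tensor(tokenizer.get_special_tokens_mask(labels[0].tolist(), already_has_special_tokens=True), dtype=torch.bool).unsqueeze(0)
mlm_mask = (torch.rand(labels.shape) < 0.15) & ~special
mlm_mask[0, 1] = True  # make sure at least one position is masked in this toy example
input_ids = labels.clone()
input_ids[mlm_mask] = tokenizer.mask_token_id

output = model(input_ids, batch['attention_mask'], batch['token_type_ids'], labels=labels, mlm_mask=mlm_mask)
output.loss.backward()

Note that `mlm_mask` must be a boolean tensor, since the forward pass uses it for boolean indexing. Also, if the generator happens to reproduce every original token, `is_replaced` is empty and the generator loss becomes NaN; a real training loop would want to guard against that edge case.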
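The `sample` docstring relies on the Gumbel-max trick: adding standard Gumbel noise to the logits and taking the argmax is an exact sample from the softmax distribution. A quick, self-contained sanity check of that equivalence (illustrative only):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5])
gumbel = torch.distributions.gumbel.Gumbel(0., 1.)

# draw many samples via argmax(logits + Gumbel noise)
draws = (logits + gumbel.sample((100_000, 3))).argmax(dim=-1)
empirical = torch.bincount(draws, minlength=3).float() / draws.numel()
print(empirical)                    # empirical frequencies ...
print(F.softmax(logits, dim=-1))    # ... closely match the softmax probabilities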