@ficstamas
Created July 6, 2023 14:53
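
"""ELECTRA pre-training wrapper built on Hugging Face Transformers.

A masked-language-model generator proposes replacements for the masked tokens,
and a discriminator is trained to detect which tokens were replaced. The two
models share their token embeddings, and the two losses are combined with the
weights suggested in the original paper (1 for the generator, 50 for the
discriminator).
"""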
from transformers import ElectraForMaskedLM, ElectraForPreTraining
from transformers.models.electra.modeling_electra import ElectraForPreTrainingOutput
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
import torch


class ElectraModel(nn.Module):
    def __init__(self,
                 generator: str = 'google/electra-base-generator',
                 discriminator: str = 'google/electra-base-discriminator',
                 loss_weights=(1, 50)):
        """
        ELECTRA model for pre-training.

        :param generator: Name of the generator model. Can be any model with an MLM head.
        :param discriminator: Name of the discriminator model. Can be any model with a binary classification head.
        :param loss_weights: (generator weight, discriminator weight); the defaults are the values suggested in the original paper.
        """
        super().__init__()
        self.generator: ElectraForMaskedLM = ElectraForMaskedLM.from_pretrained(generator)
        self.discriminator: ElectraForPreTraining = ElectraForPreTraining.from_pretrained(discriminator)
        # share the token embeddings between the generator and the discriminator
        self.discriminator.electra.embeddings = self.generator.electra.embeddings
        # usual routine to tie the embedding weights to the LM head
        self.generator.generator_lm_head.weight = self.generator.electra.embeddings.word_embeddings.weight
        self.generator_loss_fct = nn.CrossEntropyLoss()
        self.gumbel_dist = torch.distributions.gumbel.Gumbel(0., 1.)
        self.loss_weights = loss_weights

    def forward(self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor, labels: Tensor, mlm_mask: Tensor):
        """
        :param input_ids: Input IDs already modified by MLM masking
        :param attention_mask: Attention mask
        :param token_type_ids: Token type IDs
        :param labels: Original (unmasked) input IDs
        :param mlm_mask: A boolean mask marking the masked positions
        :return: ElectraForPreTrainingOutput with the combined generator/discriminator loss
        """
        # generator: predict the original tokens at the masked positions
        output = self.generator(input_ids, attention_mask, token_type_ids, labels=labels)
        generator_logits = output.logits[mlm_mask, :]
        # sampling
        with torch.no_grad():
            # sample replacement tokens from the generator distribution
            generator_tokens = self.sample(generator_logits)
            # substitute the sampled token ids at the masked positions
            discriminator_input = input_ids.clone()
            discriminator_input[mlm_mask] = generator_tokens
            # labels: a position counts as replaced only if the sampled token differs from the original
            is_replaced = mlm_mask.clone()
            is_replaced[mlm_mask] = (generator_tokens != labels[mlm_mask])
        # discriminator: predict which tokens were replaced
        output_discriminator = self.discriminator(discriminator_input, attention_mask, token_type_ids, labels=is_replaced)
        # loss: weighted sum of the generator (MLM) loss and the discriminator (replaced-token detection) loss
        generator_loss = self.generator_loss_fct(generator_logits[is_replaced[mlm_mask], :], labels[is_replaced])
        loss = output_discriminator.loss * self.loss_weights[1] + generator_loss * self.loss_weights[0]
        return ElectraForPreTrainingOutput(
            loss=loss,
            logits=output_discriminator.logits,
            hidden_states=output_discriminator.hidden_states,
            attentions=output_discriminator.attentions,
        )

    def sample(self, logits, sampling="fp32_gumbel"):
        """Reimplement Gumbel sampling because torch.nn.functional.gumbel_softmax has a bug under fp16
        (https://github.com/pytorch/pytorch/issues/41663).
        This matches what the official ELECTRA code does: standard Gumbel dist.
        = -ln(-ln(standard uniform dist.)).
        """
        if sampling == 'fp32_gumbel':
            gumbel = self.gumbel_dist.sample(logits.shape).to(logits.device)
            return (logits.float() + gumbel).argmax(dim=-1)
        elif sampling == 'fp16_gumbel':  # 5.06 ms
            gumbel = self.gumbel_dist.sample(logits.shape).to(logits.device)
            return (logits + gumbel).argmax(dim=-1)
        elif sampling == 'multinomial':  # 2.X ms
            return torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze()
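

# ------------------------------------------------------------------------------
# Usage sketch (not part of the original gist). It assumes the MLM masking is
# done with DataCollatorForLanguageModeling; the tokenizer checkpoint, the toy
# sentences, and the variable names below are illustrative assumptions.
if __name__ == "__main__":
    from transformers import ElectraTokenizerFast, DataCollatorForLanguageModeling

    tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-base-discriminator")
    collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)

    texts = [
        "ELECTRA trains a discriminator to detect replaced tokens.",
        "The generator proposes plausible replacements for the masked positions.",
    ]
    encoded = tokenizer(texts, padding=True, return_tensors="pt")
    # the collator masks a subset of tokens and sets its `labels` to -100 elsewhere
    batch = collator([{"input_ids": ids} for ids in encoded["input_ids"]])

    model = ElectraModel()  # downloads both checkpoints on first use
    output = model(
        input_ids=batch["input_ids"],            # ids with [MASK]/random replacements applied
        attention_mask=encoded["attention_mask"],
        token_type_ids=encoded["token_type_ids"],
        labels=encoded["input_ids"],             # original, unmasked ids
        mlm_mask=batch["labels"] != -100,        # True at the positions selected for MLM
    )
    print(output.loss)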