@w32zhong
Created August 20, 2022 04:42
import torch
from torch import nn
from transformers import BertTokenizer, BertForPreTraining
from transformers.models.bert.modeling_bert import BertLayer, BertOnlyMLMHead


class CondensorPretraining(nn.Module):
    def __init__(self, n_dec_layers=2, skip_from=0):
        super().__init__()
        # pretrained encoder
        self.enc = BertForPreTraining.from_pretrained(
            'bert-base-uncased',
            tie_word_embeddings=True
        )
        config = self.enc.config
        # new decoder layers plus an MLM head on top of them
        self.dec = nn.ModuleList(
            [BertLayer(config) for _ in range(n_dec_layers)]
        )
        self.dec_mlm_head = BertOnlyMLMHead(config)
        # initialize the new modules with BERT's weight-init scheme
        self.dec.apply(self.enc._init_weights)
        self.dec_mlm_head.apply(self.enc._init_weights)
        # which encoder layer's hidden states to skip-connect into the decoder
        self.skip_from = skip_from
    def forward(self, inputs, mode='condensor', cot_cls_hiddens=None):
        assert mode in ['condensor', 'cot-mae-enc', 'cot-mae-dec']
        enc_output = self.enc(
            **inputs,
            output_hidden_states=True,
            return_dict=True  # output in a dict structure
        )
        # enc_output.hidden_states is a 13-element tuple of [B, N, 768] tensors,
        # where B is the batch size and N is the sequence length. The first
        # element is the initial input embeddings; the remaining 12 are the
        # outputs of each encoder layer.
        enc_hidden_states = enc_output.hidden_states
        cls_hiddens = enc_hidden_states[-1][:, :1]               # [B, 1, 768]
        skip_hiddens = enc_hidden_states[self.skip_from][:, 1:]  # [B, N-1, 768]
        if mode == 'cot-mae-enc':
            return enc_output.prediction_logits, cls_hiddens
        elif mode == 'cot-mae-dec':
            # condition the decoder on the [CLS] vector of another span
            hiddens = torch.cat([cot_cls_hiddens, skip_hiddens], dim=1)
        elif mode == 'condensor':
            hiddens = torch.cat([cls_hiddens, skip_hiddens], dim=1)
        else:
            raise NotImplementedError
        # broadcastable additive attention mask for the decoder layers
        attention_mask = self.enc.get_extended_attention_mask(
            inputs['attention_mask'],
            inputs['attention_mask'].shape,
            inputs['attention_mask'].device
        )
        for layer in self.dec:
            # each BertLayer returns a tuple whose first element is the
            # updated hidden states (followed by attention weights, if any)
            layer_out = layer(
                hiddens,
                attention_mask,
            )
            hiddens = layer_out[0]
        dec_output_preds = self.dec_mlm_head(hiddens)
        return enc_output.prediction_logits, dec_output_preds
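
# A minimal sketch of one way to attach an MLM pretraining loss to the two
# logits outputs returned above. This is an assumption for illustration, not
# part of the gist: `labels` is taken to be standard MLM targets with
# non-masked positions set to -100, and the helper name is made up.
def mlm_pretraining_loss(enc_logits, dec_logits, labels):
    loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
    vocab_size = enc_logits.size(-1)
    enc_loss = loss_fct(enc_logits.view(-1, vocab_size), labels.view(-1))
    dec_loss = loss_fct(dec_logits.view(-1, vocab_size), labels.view(-1))
    return enc_loss + dec_loss
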
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
inputs = tokenizer('foo bar', truncation=True, return_tensors="pt")
condensor = CondensorPretraining()
enc_output, dec_output = condensor(inputs)
print(enc_output.shape)
print(dec_output.shape)
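
# Sketch of the two-pass CoT-MAE flow (illustrative usage only; the example
# texts and variable names are made up): encode one span to get its [CLS]
# vector, then let the decoder reconstruct a second span conditioned on it.
span_a = tokenizer('foo bar', truncation=True, return_tensors="pt")
span_b = tokenizer('hello world', truncation=True, return_tensors="pt")
enc_logits_a, cls_a = condensor(span_a, mode='cot-mae-enc')
enc_logits_b, dec_logits_b = condensor(span_b, mode='cot-mae-dec', cot_cls_hiddens=cls_a)
print(dec_logits_b.shape)  # [B, N_b, vocab_size]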