@edgartanaka
Last active May 18, 2022 10:06
Converting MBart to Longformer
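# convert_mbart_to_longformer.py (the conversion script; invoked by the command at the bottom of this gist)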
import argparse
import logging
import os
import copy
from transformers import MBart50Tokenizer
from transformers import MBartForConditionalGeneration, AutoTokenizer
# from transformers.modeling_bart import shift_tokens_right
from longformer_encoder_decoder import LongformerSelfAttentionForMBart, LongformerEncoderDecoderConfig
from longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def create_long_model(
    save_model_to,
    base_model,
    tokenizer_name_or_path,
    attention_window,
    max_pos
):
    # model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
    # tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
    model = MBartForConditionalGeneration.from_pretrained(base_model)
    tokenizer = MBart50Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
    config = LongformerEncoderDecoderConfig.from_pretrained(base_model)
    model.config = config

    # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention
    # expects attention_probs_dropout_prob, so set it here
    config.attention_probs_dropout_prob = config.attention_dropout
    config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]

    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape
    assert current_max_pos == config.max_position_embeddings + 2

    config.max_encoder_position_embeddings = max_pos
    config.max_decoder_position_embeddings = config.max_position_embeddings
    del config.max_position_embeddings
    max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
    assert max_pos >= current_max_pos

    # allocate a larger position embedding matrix for the encoder
    new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # copy position embeddings over and over to initialize the new position embeddings
    k = 2
    step = current_max_pos - 2
    while k < max_pos - 1:
        new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
        k += step
    model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    # allocate a larger position embedding matrix for the decoder
    # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # # copy position embeddings over and over to initialize the new position embeddings
    # k = 2
    # step = current_max_pos - 2
    # while k < max_pos - 1:
    #     new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:]
    #     k += step
    # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed

    # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers
    for i, layer in enumerate(model.model.encoder.layers):
        longformer_self_attn_for_bart = LongformerSelfAttentionForMBart(config, layer_id=i)

        longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
        longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
        longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj

        longformer_self_attn_for_bart.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
        longformer_self_attn_for_bart.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
        longformer_self_attn_for_bart.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

        longformer_self_attn_for_bart.output = layer.self_attn.out_proj

        layer.self_attn = longformer_self_attn_for_bart

    # save model
    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer
# def mask_test(args):
#     tokenizer = MBartTokenizer.from_pretrained(args.save_model_to)
#     TXT = "My friends are <mask> but they eat too many carbs."
#     model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)
#     model.model.encoder.config.gradient_checkpointing = True
#     model.model.decoder.config.gradient_checkpointing = True
#     data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
#     input_ids = data['input_ids']
#     attention_mask = data['attention_mask']
#     decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
#     logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
#     masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
#     probs = logits[0, masked_index].softmax(dim=0)
#     values, predictions = probs.topk(5)
#     print(tokenizer.convert_ids_to_tokens(predictions))
def summary_test(args):
    tokenizer = MBart50Tokenizer.from_pretrained(args.save_model_to)
    # TXT = "My friends are <mask> but they eat too many carbs."
    model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)

    # What are these doing?!
    # I un-commented them because I think they fix the problem with the arguments of the forward function.
    model.model.encoder.config.gradient_checkpointing = True
    model.model.decoder.config.gradient_checkpointing = True

    # ARTICLE_TO_SUMMARIZE = "My friends are cool, but they eat too much carbs."
    with open('article_es.txt', 'r') as file:
        ARTICLE_TO_SUMMARIZE = file.read().replace('\n', '')
    inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=4096, return_tensors='pt', padding="max_length", truncation=True)

    # Generate Summary
    print(inputs['input_ids'])
    print('length input ids:', inputs['input_ids'].size())
    print('w = ', model.model.config.attention_window)
    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
def main():
    parser = argparse.ArgumentParser(
        description="Convert MBart to LongMBart. Replaces the MBart encoder's SelfAttention with LongformerSelfAttention")
    parser.add_argument(
        '--base_model',
        type=str,
        default='facebook/mbart-large-50',
        help='The name or path of the base model you want to convert'
    )
    parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        default='facebook/mbart-large-50',
        help='The name or path of the tokenizer'
    )
    parser.add_argument(
        '--save_model_to',
        type=str,
        required=True,
        help='The path to save the converted model'
    )
    parser.add_argument(
        '--attention_window',
        type=int,
        default=512,
        help='attention window size for longformer self attention (one sided)'
    )
    parser.add_argument(
        '--max_pos',
        type=int,
        default=4096,
        help='maximum encoder positions'
    )
    args = parser.parse_args()

    if not os.path.exists(args.save_model_to):
        os.mkdir(args.save_model_to)

    create_long_model(
        save_model_to=args.save_model_to,
        base_model=args.base_model,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        attention_window=args.attention_window,
        max_pos=args.max_pos
    )

    summary_test(args)


if __name__ == "__main__":
    main()
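# longformer_encoder_decoder.py (the module imported by the conversion script above)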
from typing import List, Optional, Tuple, Dict
from torch import nn, Tensor
# from longformer.longformer import LongformerSelfAttention
from transformers import LongformerSelfAttention
from transformers import MBartConfig, MBartForConditionalGeneration
from transformers.models.mbart.modeling_mbart import MBartLearnedPositionalEmbedding
class LongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        if config.attention_mode == 'n2':
            pass  # do nothing, use BertSelfAttention instead
        else:
            self.model.encoder.embed_positions = MBartLearnedPositionalEmbedding(4096, 1024)
            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)
class LongformerEncoderDecoderConfig(MBartConfig):
    def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`,
                which is 256 on each side.
            attention_dilation: list of attention dilations of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: do autoregressive attention or attend to both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of
                Longformer self-attention, 'sliding_chunks' for another implementation of Longformer
                self-attention
        """
        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
class LongformerSelfAttentionForMBart(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states=None,
        attention_mask=None,
        layer_head_mask=None,
        output_attentions=False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # NEW
        outputs = self.longformer_self_attn(
            hidden_states=hidden_states,        # I'm guessing I just need to pass
            attention_mask=attention_mask,      # I'm guessing I just need to pass
            layer_head_mask=layer_head_mask,    # I'm guessing I just need to pass
            is_index_masked=None,
            is_index_global_attn=None,
            is_global_attn=None,
            output_attentions=output_attentions,
        )

        attn_output = self.output(outputs[0].transpose(0, 1))

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
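As an aside, here is a minimal usage sketch (not part of the gist itself) of how LongformerEncoderDecoderConfig is meant to be populated, following the per-layer-list convention described in its docstring; the layer count of 12 assumes the facebook/mbart-large-50 encoder:

config = LongformerEncoderDecoderConfig.from_pretrained(
    'facebook/mbart-large-50',
    attention_window=[512] * 12,   # one (one-sided) window size per encoder layer
    attention_dilation=[1] * 12,   # 1 means no dilation
    attention_mode='sliding_chunks',
)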
python convert_mbart_to_longformer.py --save_model_to model_dir \
    --base_model facebook/mbart-large-50 \
    --tokenizer_name_or_path facebook/mbart-large-50
dodo-robot commented May 18, 2022

Hey man, I'm not sure you're still looking to solve this, but I was struggling with the same problem and, thanks to this guy https://github.com/Taeksu-Kim/longformer_kobart, I've managed to generate summaries with my brand new Longformer BART.

Two things need to be changed:

  • in LongformerEncoderDecoderForConditionalGeneration you need to add the MBart positional embedding to the decoder as well (a sketch follows the code below)
  • according to Taeksu-Kim, the forward in the self-attention layer needs to look like this in order not to raise the "too many values to unpack" error:
def forward(
       self,
       hidden_states: torch.Tensor,
       key_value_states: Optional[torch.Tensor] = None,
       past_key_value: Optional[Tuple[torch.Tensor]] = None,
       attention_mask: Optional[torch.Tensor] = None,
       layer_head_mask: Optional[torch.Tensor] = None,
       output_attentions: bool = False,
   ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

       is_cross_attention = key_value_states is not None
       bsz, tgt_len, embed_dim = hidden_states.size()

       attention_mask = attention_mask.squeeze(dim=1)
       attention_mask = attention_mask[:,0]

       is_index_masked = attention_mask < 0
       is_index_global_attn = attention_mask > 0
       is_global_attn = is_index_global_attn.flatten().any().item()

       outputs = self.longformer_self_attn(
           hidden_states,
           attention_mask=attention_mask,
           layer_head_mask=None,
           is_index_masked=is_index_masked,
           is_index_global_attn=is_index_global_attn,
           is_global_attn=is_global_attn,
           output_attentions=output_attentions,
       )

       attn_output = self.output(outputs[0])

       return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None, None)
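For the first point, a minimal sketch of what the decoder change could look like inside LongformerEncoderDecoderForConditionalGeneration.__init__, mirroring the encoder line already in the class above (the decoder length of 1024 is an assumption matching mbart-large-50's default max_position_embeddings):

        self.model.encoder.embed_positions = MBartLearnedPositionalEmbedding(4096, 1024)
        # assumption: the decoder keeps its original 1024 positions, with d_model = 1024
        self.model.decoder.embed_positions = MBartLearnedPositionalEmbedding(1024, 1024)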


best regards
