@arteagac
Created November 16, 2023 23:38
Expand BERT beyond 512 tokens
# Load the model
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
## EXPAND POSITION EMBEDDINGS TO 1024 TOKENS
max_length = 1024
tokenizer.model_max_length = max_length
model.config.max_position_embeddings = max_length
model.base_model.embeddings.position_ids = torch.arange(max_length).expand((1, -1))
model.base_model.embeddings.token_type_ids = torch.zeros(max_length, dtype=torch.long).expand((1, -1))
orig_pos_emb = model.base_model.embeddings.position_embeddings.weight
model.base_model.embeddings.position_embeddings.weight = torch.nn.Parameter(torch.cat((orig_pos_emb, orig_pos_emb)))
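# Optional sanity check (not in the original gist): the duplicated matrix should
# now provide one embedding row per position up to max_length.
assert model.base_model.embeddings.position_embeddings.weight.shape[0] == max_length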
## TEST THE MODEL IN A QUESTION ANSWERING TASK
question = "Where is the largest airport in the united states?"
# Simulate initial ~600 tokens by repeating 60 times a phrase of length 10
simul_tokens = " ".join(60*["This phrases simulates the initial 600 tokens by simple repetition. "])
# Place the answer to the question at the end of the 600 simulated tokens.
context = simul_tokens + "The largest airport in the United States is located in Atlanta."
# Use the question answering model
inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
outputs = model(**inputs)
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
tokenizer.decode(predict_answer_tokens)
# OUTPUT: atlanta
# The correct output demonstrates BERT was able to attend beyond 512 tokens
# thanks to the expansion in position embeddings.
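The same duplication trick can be generalized to other target lengths. The helper below is a sketch, not part of the original gist: it tiles the pretrained position-embedding matrix as many times as needed and truncates to the requested length. Positions beyond 512 simply reuse pretrained embeddings, so some fine-tuning may still be needed for good quality on long inputs.

# Sketch: generalize the expansion above to an arbitrary new_max_length.
# Assumes a BERT-style model whose embeddings live at model.base_model.embeddings.
def expand_position_embeddings(model, tokenizer, new_max_length):
    emb = model.base_model.embeddings
    orig_weight = emb.position_embeddings.weight              # (512, hidden_size) for the original checkpoint
    n_copies = -(-new_max_length // orig_weight.shape[0])     # ceiling division
    with torch.no_grad():
        new_weight = orig_weight.repeat(n_copies, 1)[:new_max_length]
    emb.position_embeddings.weight = torch.nn.Parameter(new_weight)
    emb.position_ids = torch.arange(new_max_length).expand((1, -1))
    emb.token_type_ids = torch.zeros(new_max_length, dtype=torch.long).expand((1, -1))
    model.config.max_position_embeddings = new_max_length
    tokenizer.model_max_length = new_max_length

# Example: expand to 2048 positions.
expand_position_embeddings(model, tokenizer, 2048)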
@deshwalmahesh
deshwalmahesh commented Nov 27, 2023

Anyone looking at this: you can use the snippet below to confirm that feeding 1024 tokens to the base pre-trained model (without expanding the position embeddings) raises the error: RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1

However, if you switch the model to "microsoft/deberta-v3-base", even a max_length of 8096 runs without errors, because of the relative position embeddings and the type of attention it uses.

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")


## TEST THE MODEL IN A QUESTION ANSWERING TASK
question = "Where is the largest airport in the united states?"

# Simulate roughly 600 initial tokens by repeating a ~10-token phrase 60 times
simul_tokens = " ".join(60 * ["This phrase simulates the initial 600 tokens by simple repetition."])
# Place the answer to the question at the end of the 600 simulated tokens.
context = simul_tokens + "The largest airport in the United States is located in Atlanta."

# Use the question answering model
inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=1024, padding="max_length")
outputs = model(**inputs)
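
For reference, here is a sketch of the DeBERTa variant mentioned above. Assumption: "microsoft/deberta-v3-base" has no SQuAD-finetuned QA head (its qa_outputs layer is freshly initialized), so the point is only that the forward pass on a 1024-token input runs without the size-mismatch error, not that the predicted span is meaningful.

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = AutoModelForQuestionAnswering.from_pretrained("microsoft/deberta-v3-base")

# Reuse the question/context built above; pad to 1024 tokens as before.
inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=1024, padding="max_length")
with torch.no_grad():
    outputs = model(**inputs)  # no RuntimeError: DeBERTa uses relative position embeddings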
