Expand BERT beyond 512 tokens
# Load the model
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

## EXPAND POSITION EMBEDDINGS TO 1024 TOKENS
max_length = 1024
tokenizer.model_max_length = max_length
model.config.max_position_embeddings = max_length
# Overwrite the cached position id and token type id buffers so they cover the new length.
model.base_model.embeddings.position_ids = torch.arange(max_length).expand((1, -1))
model.base_model.embeddings.token_type_ids = torch.zeros(max_length, dtype=torch.long).expand((1, -1))
# Duplicate the 512 learned position embeddings to obtain a 1024-row embedding matrix.
orig_pos_emb = model.base_model.embeddings.position_embeddings.weight
model.base_model.embeddings.position_embeddings.weight = torch.nn.Parameter(torch.cat((orig_pos_emb, orig_pos_emb)))

## TEST THE MODEL IN A QUESTION ANSWERING TASK
question = "Where is the largest airport in the United States?"
# Simulate roughly 600 initial tokens by repeating a ~10-token phrase 60 times.
simul_tokens = " ".join(60 * ["This phrase simulates the initial 600 tokens by simple repetition. "])
# Place the answer to the question at the end of the ~600 simulated tokens.
context = simul_tokens + "The largest airport in the United States is located in Atlanta."
# Run the question answering model
inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
outputs = model(**inputs)
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
print(tokenizer.decode(predict_answer_tokens))
# OUTPUT: atlanta
# The correct output demonstrates that BERT was able to attend beyond 512 tokens
# thanks to the expanded position embeddings.
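A quick sanity check one could append here (not part of the original gist, just a suggested addition): confirm that the tokenized input really exceeds 512 tokens, so the answer span truly lies past the original limit.

# Suggested check: the sequence length should be well above 512 for this test to be meaningful.
assert inputs.input_ids.shape[1] > 512, "Context shorter than 512 tokens; the expansion would not be exercised."
print(inputs.input_ids.shape[1])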
Anyone looking at this: you can use this gist to check that simply feeding 1024 tokens into the unmodified pre-trained model fails with RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1. But if you change the model to "microsoft/deberta-v3-base" and use even 8096 tokens, it won't give you errors, because of the type of attention and position embeddings DeBERTa uses (relative positions rather than learned absolute ones).
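For reference, a minimal sketch of the DeBERTa check described in this comment (the model name comes from the comment; the input length is an arbitrary choice for illustration, and sentencepiece must be installed for the DeBERTa tokenizer):

from transformers import AutoTokenizer, AutoModel
import torch

model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Build an input well beyond 512 tokens; no position-embedding surgery is needed here
# because DeBERTa encodes positions relatively inside the attention layers.
long_text = " ".join(200 * ["This phrase simulates a long document by simple repetition."])
inputs = tokenizer(long_text, return_tensors="pt", truncation=False)
print(inputs.input_ids.shape)  # sequence length far above 512

with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # runs without the 512-position size-mismatch error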