@victordibia
Last active June 22, 2020 03:00
How to Explain HuggingFace BERT for Question Answering NLP Models with TF 2.0 GradientTape
import tensorflow as tf

def get_gradient(question, context, model, tokenizer):
    """Return gradients of the input (question + context) w.r.t. the model's span prediction.

    Args:
        question (str): text of the input question
        context (str): text of the question context/passage
        model (QA model): Hugging Face BERT model for QA, e.g.
            transformers.modeling_tf_distilbert.TFDistilBertForQuestionAnswering or
            transformers.modeling_tf_bert.TFBertForQuestionAnswering
        tokenizer (tokenizer): transformers.tokenization_bert.BertTokenizerFast

    Returns:
        (tuple): (gradients, token_words, token_types, answer_text)
    """
    # for TFDistilBertForQuestionAnswering the matrix lives at
    # model.distilbert.embeddings.word_embeddings instead
    embedding_matrix = model.bert.embeddings.word_embeddings
    encoded_tokens = tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="tf")
    token_ids = list(encoded_tokens["input_ids"].numpy()[0])
    vocab_size = embedding_matrix.get_shape()[0]

    # convert token ids to one hot; we can't differentiate w.r.t. integer token ids,
    # hence the need for a one-hot representation
    token_ids_tensor = tf.constant([token_ids], dtype="int32")
    token_ids_tensor_one_hot = tf.one_hot(token_ids_tensor, vocab_size)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        # (i) watch the input variable
        tape.watch(token_ids_tensor_one_hot)

        # multiply input by the embedding matrix; this lets us backprop w.r.t. the one-hot input
        inputs_embeds = tf.matmul(token_ids_tensor_one_hot, embedding_matrix)

        # (ii) get the prediction (older transformers versions return a (start, end) logits tuple)
        start_scores, end_scores = model({
            "inputs_embeds": inputs_embeds,
            "token_type_ids": encoded_tokens["token_type_ids"],
            "attention_mask": encoded_tokens["attention_mask"],
        })
        answer_start, answer_end = get_best_start_end_position(start_scores, end_scores)

        start_output_mask = get_correct_span_mask(answer_start, len(token_ids))
        end_output_mask = get_correct_span_mask(answer_end, len(token_ids))

        # zero out all predictions outside of the correct span positions;
        # we want gradients w.r.t. just these positions
        predict_correct_start_token = tf.reduce_sum(start_scores * start_output_mask)
        predict_correct_end_token = tf.reduce_sum(end_scores * end_output_mask)

    # (iii) get the gradient of the input with respect to both the start and end outputs
    gradient_non_normalized = tf.norm(
        tape.gradient([predict_correct_start_token, predict_correct_end_token], token_ids_tensor_one_hot),
        axis=2)

    # (iv) normalize gradient scores and return them as "explanations"
    gradient_tensor = gradient_non_normalized / tf.reduce_max(gradient_non_normalized)
    gradients = gradient_tensor[0].numpy().tolist()

    token_words = tokenizer.convert_ids_to_tokens(token_ids)
    token_types = list(encoded_tokens["token_type_ids"].numpy()[0])
    # convert_tokens_to_string expects tokens (strings), not ids
    answer_text = tokenizer.convert_tokens_to_string(token_words[answer_start:answer_end])
    return gradients, token_words, token_types, answer_text
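The two helpers used above, get_best_start_end_position and get_correct_span_mask, are defined elsewhere in the gist. A minimal sketch consistent with how they are called (argmax over the span logits, and a one-hot float mask over token positions) might look like this:

import numpy as np
import tensorflow as tf

def get_best_start_end_position(start_scores, end_scores):
    # highest-scoring start index, and an end index made exclusive for slicing
    answer_start = tf.argmax(start_scores, axis=1).numpy()[0]
    answer_end = (tf.argmax(end_scores, axis=1) + 1).numpy()[0]
    return answer_start, answer_end

def get_correct_span_mask(correct_index, token_size):
    # float mask with 1.0 at the predicted position and 0.0 elsewhere
    span_mask = np.zeros((1, token_size))
    span_mask[0, correct_index] = 1
    return tf.constant(span_mask, dtype="float32")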
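A usage sketch, assuming a SQuAD-finetuned BERT checkpoint from the Hugging Face model hub (the checkpoint name and the example question/context are illustrative, not part of the gist):

from transformers import BertTokenizerFast, TFBertForQuestionAnswering

model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = TFBertForQuestionAnswering.from_pretrained(model_name)

question = "What is the capital of France?"
context = "Paris is the capital and most populous city of France."
gradients, token_words, token_types, answer_text = get_gradient(question, context, model, tokenizer)

print(answer_text)
# tokens with the largest normalized gradient contributed most to the predicted span
for word, score in sorted(zip(token_words, gradients), key=lambda pair: -pair[1])[:5]:
    print(word, round(score, 3))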