We use
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=causal_mask is None and q_len > 1,
)
because of pytorch/pytorch#108108: the causal mask implied by is_causal is aligned
incorrectly in some cases, including decoding with a KV cache (q_len = 1 while kv_len > 1).
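A minimal sketch of that misalignment on toy shapes (not the transformers code): with a single
query token and several cached keys, is_causal=True builds the top-left aligned mask
torch.ones(q_len, kv_len).tril(), so the new token only attends to the first cached position
instead of all of them.

import torch
import torch.nn.functional as F

# Decoding with a KV cache: one new query token, 5 cached key/value positions.
q = torch.randn(1, 1, 1, 8)   # (batch, heads, q_len=1, head_dim)
k = torch.randn(1, 1, 5, 8)
v = torch.randn(1, 1, 5, 8)

# is_causal=True masks with ones(q_len, kv_len).tril(), top-left aligned:
# the single query can only see the first cached key.
misaligned = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The new token is the last position and should attend to every cached key,
# which is what is_causal=False with no mask gives here.
expected = F.scaled_dot_product_attention(q, k, v, is_causal=False)

print(torch.allclose(misaligned, expected))  # False: the outputs differ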
In the prefill graph, is_causal=True is hard-coded, since causal_mask is None and
q_len > 1 at trace time:
[2024-04-08 11:57:34,687] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] attn_output_48 = torch._C._nn.scaled_dot_product_attention(query_states_38, key_states_96, value_states_83, attn_mask = None, dropout_p = 0.0, is_causal = True); query_states_38 = None
The first decoding step then triggers a recompilation, because the sequence length of input_ids changes from 4 (prefill) to 1:
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /home/felix/transformers/src/transformers/models/llama/modeling_llama.py:1163
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['input_ids']' size mismatch at index 1. expected 4, actual 1
The recompiled decoding graph then has is_causal=False hard-coded, since q_len == 1 at
trace time:
[2024-04-08 11:58:21,913] [0/1] torch._dynamo.output_graph.__graph_code: [DEBUG] attn_output_24 = torch._C._nn.scaled_dot_product_attention(query_states_20, key_states_66, value_states_59, attn_mask = None, dropout_p = 0.0, is_causal = False); query_states_20 = None
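The specialization can be reproduced on a toy function (hypothetical names and shapes, not the
transformers code); running it with TORCH_LOGS="recompiles,graph_code" shows output equivalent
to the logs above. The q_len comparison is resolved to a plain Python bool while tracing, so
each graph bakes in a constant is_causal and Dynamo guards on the query length.

import torch
import torch.nn.functional as F

def toy_attention(q, k, v):
    # Resolved at trace time, so torch.compile bakes a constant is_causal
    # into each graph and guards on q's sequence length.
    is_causal = True if q.shape[2] > 1 else False
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

compiled = torch.compile(toy_attention)

# "Prefill": q_len = kv_len = 4 -> graph traced with is_causal = True.
x = torch.randn(1, 1, 4, 8)
compiled(x, x, x)

# "Decoding": q_len = 1, kv_len = 5 -> the size guard fails and Dynamo
# recompiles a second graph with is_causal = False.
q = torch.randn(1, 1, 1, 8)
kv = torch.randn(1, 1, 5, 8)
compiled(q, kv, kv)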