import math
import torch
from torch import nn
from transformers import BloomConfig, AutoTokenizer
from transformers.models.bloom.modeling_bloom import build_alibi_tensor as build_alibi_tensor_old
from transformers.models.bloom.modeling_bloom import pre_process_alibi_for_pad as pre_process_alibi_for_pad_old
from transformers.models.bloom.modeling_bloom import BloomAttention as BloomAttentionOld
def build_alibi_tensor(
    max_seq_len: int, num_attention_heads: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device('cpu')
) -> torch.Tensor:
    closest_power_of_2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
                        device=device, dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_attention_heads:
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
                                  device=device, dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, num_attention_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
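

# Illustrative sketch (not part of the original gist): the returned tensor has shape
# [num_attention_heads, 1, max_seq_len] -- one ALiBi slope per head, multiplied by the
# position indices 0..max_seq_len-1. For example, assuming 4 heads, the per-head slopes
# form the geometric sequence 1/4, 1/16, 1/64, 1/256:
#   alibi = build_alibi_tensor(3, 4, dtype=torch.float32)
#   assert alibi.shape == (4, 1, 3)
#   assert torch.allclose(alibi[:, 0, 1], torch.tensor([1/4, 1/16, 1/64, 1/256]))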
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
    (num_attention_heads, _one, max_seq_len), (batch_size, _max_seq_len) = alibi.shape, attention_mask.shape
    assert _one == 1 and max_seq_len == _max_seq_len
    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
    # ^-- [batch, max_len], values correspond to element indices after removing padding
    return alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1)
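

# Illustrative sketch (not part of the original gist): for a left-padded row, the cumsum
# trick re-indexes alibi so that the first real (non-padding) token gets position 0.
# E.g. with mask [0, 0, 1, 1, 1], unpadded_indices is [0, 0, 0, 1, 2], so each head's
# row [0, slope, 2*slope, 3*slope, 4*slope] becomes [0, 0, 0, slope, 2*slope]; the padded
# entries keep a harmless value of 0 and are masked out by the attention mask later anyway.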
class BloomAttentionGood(BloomAttentionOld):
    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # apply preprocessing if the input is padded
        max_seq_len = hidden_states.shape[1] + (layer_past[0].shape[1] if layer_past is not None else 0)
        alibi = build_alibi_tensor(max_seq_len, self.num_heads, hidden_states.dtype, hidden_states.device)
        if attention_mask is not None:  # <-- do not check for 0 in mask to avoid host-device sync
            alibi = pre_process_alibi_for_pad(alibi, attention_mask)

        bias = self.query_key_value.bias
        mixed_x_layer = nn.functional.linear(hidden_states, self.query_key_value.weight, bias)

        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = mixed_x_layer.split(self.head_dim, dim=-1)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
            # ^-- TODO: this causes insane memory fragmentation

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        # [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))

        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.baddbmm(
            alibi,  # <-- note: not entirely sure, but it appears we don't need to slice alibi as before
            query_layer.transpose(1, 0),
            key_layer.permute(1, 2, 0),
            beta=1.0 / self.layer_number,
            alpha=1.0 / self.norm_factor,
        )

        ##############################################################
        # TODO: THERE ARE NO MORE CHANGES AFTER THIS LINE, REMOVE IT #
        ##############################################################

        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
            value_layer.dtype
        )
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # context layer shape: [batch_size, num_heads, q_length, head_dim]
        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [k_length, batch_size x num_heads, head_dim]
        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output. [q_length, batch_size, hidden_size]
        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = context_layer.shape[-1] / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + nn.functional.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = nn.functional.linear(context_layer, self.dense.weight)

        output_bias = self.dense.bias
        output = output_tensor.transpose(1, 0)

        outputs = (output, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs, output_bias
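

# Illustrative usage sketch (not part of the original gist; mirrors Test 3 below). Unlike
# BloomAttentionOld.forward, the patched forward takes no precomputed `alibi` argument --
# it rebuilds and pad-adjusts alibi internally from `attention_mask`, so callers only pass
# hidden states and the mask:
#   layer = BloomAttentionGood(config, layer_number=11).eval()
#   (output, present), bias = layer(hidden_states, attention_mask=attention_mask)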
if __name__ == '__main__':
    print("Test 1: build_alibi_tensor")
    lengths = (1, 20, 31, 512, 1337, 2048)
    for max_seq_len in lengths:
        for dtype in torch.float32, torch.bfloat16, torch.float16:
            for num_attention_heads in range(1, 256):
                ours = build_alibi_tensor(max_seq_len, num_attention_heads, dtype)
                ref = build_alibi_tensor_old(max_seq_len, num_attention_heads, torch.float32).to(dtype)
                atol = 1e-7 if dtype == torch.float32 else 3e-5
                assert torch.allclose(ref, ours, atol=atol, rtol=999), (max_seq_len, num_attention_heads, dtype)
    print("Passed build_alibi_tensor!")

    print("Test 2: pre_process_alibi_for_pad")
    num_attention_heads = 112
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 0, 0, 0],
                                   [0, 1, 1, 1, 0, 0, 0],
                                   [0, 0, 1, 1, 1, 1, 1],
                                   [1, 0, 0, 1, 1, 0, 1],
                                   [0, 0, 0, 0, 0, 0, 0]],
                                  dtype=torch.bool)
    alibi = build_alibi_tensor(attention_mask.shape[1], num_attention_heads)
    padded_alibi_ours = pre_process_alibi_for_pad(alibi, attention_mask)
    for i in range(len(attention_mask)):
        attention_mask_i = attention_mask[i:i+1, :]
        padded_alibi_ref_i = pre_process_alibi_for_pad_old(alibi, attention_mask_i, num_attention_heads)
        assert torch.allclose(padded_alibi_ref_i * attention_mask_i, padded_alibi_ours[:, i:i+1] * attention_mask_i)
    print("Passed pre_process_alibi_for_pad!")

    print("Test 3: attention layer")
    config = BloomConfig.from_pretrained("bigscience/bloom-6b3")
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-6b3")
    layer = BloomAttentionOld(config, layer_number=11).train(False)
    attention_mask = torch.tensor([[0, 1, 1, 0, 1, 1, 1]]).bool()
    hidden_states = torch.randn(1, attention_mask.shape[1], config.hidden_size)
    alibi = build_alibi_tensor_old(hidden_states.shape[1], config.num_attention_heads, dtype=torch.float32)
    ref = layer.forward(hidden_states=hidden_states, alibi=alibi, attention_mask=attention_mask)
    ours = BloomAttentionGood.forward(layer, hidden_states=hidden_states, attention_mask=attention_mask)
    assert torch.allclose(ours[0][0] + ours[1], ref[0][0] + ref[1], atol=1e-6, rtol=999)
    print("Passed attention layer!")