# Gist by @justheuristic, created June 17, 2022
import math
import torch
from torch import nn
from transformers import BloomConfig, AutoTokenizer
from transformers.models.bloom.modeling_bloom import build_alibi_tensor as build_alibi_tensor_old
from transformers.models.bloom.modeling_bloom import pre_process_alibi_for_pad as pre_process_alibi_for_pad_old
from transformers.models.bloom.modeling_bloom import BloomAttention as BloomAttentionOld
def build_alibi_tensor(
    max_seq_len: int,
    num_attention_heads: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device('cpu'),
) -> torch.Tensor:
    closest_power_of_2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_attention_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_attention_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
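

# Quick sanity sketch (editor's addition, not part of the original gist): the result is one
# ALiBi bias row per head, shape [num_attention_heads, 1, max_seq_len]; with 4 heads, head 0
# gets slope 0.25, so its biases are 0.25 * [0, 1, 2, ...].
_example_alibi = build_alibi_tensor(3, 4, torch.float32)
assert _example_alibi.shape == (4, 1, 3)
assert torch.allclose(_example_alibi[0, 0], 0.25 * torch.arange(3, dtype=torch.float32))
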
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
    (num_attention_heads, _one, max_seq_len), (batch_size, _max_seq_len) = alibi.shape, attention_mask.shape
    assert _one == 1 and max_seq_len == _max_seq_len
    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
    # ^-- [batch, max_len], values correspond to element indices after removing padding
    return alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1)
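

# Illustrative check (editor's addition, not part of the original gist): with left-padding,
# the bias restarts counting from the first real token. For a mask [[0, 1, 1]] the unpadded
# indices become [0, 0, 1], so per-position biases [0, s, 2*s] turn into [0, 0, s].
_example_mask = torch.tensor([[0, 1, 1]], dtype=torch.bool)
_example_shifted = pre_process_alibi_for_pad(build_alibi_tensor(3, 4, torch.float32), _example_mask)
assert torch.allclose(_example_shifted[0, 0], torch.tensor([0.0, 0.0, 0.25]))
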
class BloomAttentionGood(BloomAttentionOld):
    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # apply preprocessing if the input is padded
        max_seq_len = hidden_states.shape[1] + (layer_past[0].shape[1] if layer_past is not None else 0)
        alibi = build_alibi_tensor(max_seq_len, self.num_heads, hidden_states.dtype, hidden_states.device)
        if attention_mask is not None:  # <-- do not check for 0 in mask to avoid host-device sync
            alibi = pre_process_alibi_for_pad(alibi, attention_mask)

        bias = self.query_key_value.bias
        mixed_x_layer = nn.functional.linear(hidden_states, self.query_key_value.weight, bias)

        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = mixed_x_layer.split(self.head_dim, dim=-1)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
            # ^-- TODO: this causes insane memory fragmentation

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        # [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))

        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)

        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.baddbmm(
            alibi,  # <-- note: not entirely sure, but it appears we no longer need to slice alibi as before
            query_layer.transpose(1, 0),
            key_layer.permute(1, 2, 0),
            beta=1.0 / self.layer_number,
            alpha=1.0 / self.norm_factor,
        )

        ##############################################################
        # TODO: THERE ARE NO MORE CHANGES AFTER THIS LINE, REMOVE IT #
        ##############################################################

        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
            value_layer.dtype
        )
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # context layer shape: [batch_size, num_heads, q_length, head_dim]
        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [k_length, batch_size x num_heads, head_dim]
        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output. [q_length, batch_size, hidden_size]
        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = context_layer.shape[-1] / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + nn.functional.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = nn.functional.linear(context_layer, self.dense.weight)

        output_bias = self.dense.bias
        output = output_tensor.transpose(1, 0)

        outputs = (output, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs, output_bias
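

# Possible usage (editor's sketch, not part of the original gist): the patched forward can be
# applied to an existing BloomAttention module without re-instantiating it, e.g.
#   outputs, output_bias = BloomAttentionGood.forward(layer, hidden_states, attention_mask=attention_mask)
# which is exactly how Test 3 below calls it; note that alibi is now built inside forward,
# so it is no longer passed in as an argument.
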
if __name__ == '__main__':
    print("Test 1: build_alibi_tensor")
    lengths = (1, 20, 31, 512, 1337, 2048)
    for max_seq_len in lengths:
        for dtype in (torch.float32, torch.bfloat16, torch.float16):
            for num_attention_heads in range(1, 256):
                ours = build_alibi_tensor(max_seq_len, num_attention_heads, dtype)
                ref = build_alibi_tensor_old(max_seq_len, num_attention_heads, torch.float32).to(dtype)
                atol = 1e-7 if dtype == torch.float32 else 3e-5
                assert torch.allclose(ref, ours, atol=atol, rtol=999), (max_seq_len, num_attention_heads, dtype)
    print("Passed build_alibi_tensor!")

    print("Test 2: pre_process_alibi_for_pad")
    num_attention_heads = 112
    attention_mask = torch.tensor(
        [[1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 0, 0, 0],
         [0, 1, 1, 1, 0, 0, 0],
         [0, 0, 1, 1, 1, 1, 1],
         [1, 0, 0, 1, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.bool,
    )
    alibi = build_alibi_tensor(attention_mask.shape[1], num_attention_heads)
    padded_alibi_ours = pre_process_alibi_for_pad(alibi, attention_mask)
    for i in range(len(attention_mask)):
        attention_mask_i = attention_mask[i:i + 1, :]
        padded_alibi_ref_i = pre_process_alibi_for_pad_old(alibi, attention_mask_i, num_attention_heads)
        assert torch.allclose(padded_alibi_ref_i * attention_mask_i, padded_alibi_ours[:, i:i + 1] * attention_mask_i)
    print("Passed pre_process_alibi_for_pad!")

    print("Test 3: attention layer")
    config = BloomConfig.from_pretrained("bigscience/bloom-6b3")
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-6b3")
    layer = BloomAttentionOld(config, layer_number=11).train(False)
    attention_mask = torch.tensor([[0, 1, 1, 0, 1, 1, 1]]).bool()
    hidden_states = torch.randn(1, attention_mask.shape[1], config.hidden_size)
    alibi = build_alibi_tensor_old(hidden_states.shape[1], config.num_attention_heads, dtype=torch.float32)
    ref = layer.forward(hidden_states=hidden_states, alibi=alibi, attention_mask=attention_mask)
    ours = BloomAttentionGood.forward(layer, hidden_states=hidden_states, attention_mask=attention_mask)
    assert torch.allclose(ours[0][0] + ours[1], ref[0][0] + ref[1], atol=1e-6, rtol=999)
    print("Passed attention layer!")