@ebsmothers
Created June 27, 2024 16:54
%%timeit
# (assumes torch has been imported in a previous cell)
seq_lens = [100] * 160
max_seq_len = 16384
num_samples_in_pack = len(seq_lens)
block_attn_masks = []
total_seq_len = 0
for i, seq_len in enumerate(seq_lens):
    total_seq_len += seq_len
    # Append lower triangular matrix for causal mask
    block_attn_masks.append(
        torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    )
    # If we're at the last sample and the total seq len is less than the max seq len,
    # we need to pad with an identity matrix for the remainder
    if i == num_samples_in_pack - 1 and total_seq_len < max_seq_len:
        block_attn_masks.append(
            torch.eye(
                max_seq_len - total_seq_len,
                max_seq_len - total_seq_len,
                dtype=torch.bool,
            )
        )
mask = torch.block_diag(*block_attn_masks)
>>> 5.9 ms ± 175 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
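For reference, a minimal sketch (not part of the original gist) of the same construction on a tiny pack, assuming seq_lens = [2, 3] and max_seq_len = 8, to show the block-diagonal causal blocks plus the identity padding block:

import torch

seq_lens = [2, 3]
max_seq_len = 8
# Causal (lower-triangular) block per sample
blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens]
# Identity block for the padded remainder
remainder = max_seq_len - sum(seq_lens)
if remainder > 0:
    blocks.append(torch.eye(remainder, dtype=torch.bool))
mask = torch.block_diag(*blocks)
print(mask.long())
# tensor([[1, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1]])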
%%timeit
seq_lens = [100] * 160
max_seq_len = 16384
num_samples_in_pack = len(seq_lens)
padding = max_seq_len - sum(seq_lens)
is_padding = [False] * len(seq_lens)
if padding > 0:
    seq_lens = seq_lens + [padding]
    is_padding = is_padding + [True]
# Single block_diag call: causal block for each real sample, identity block for padding
mask = torch.block_diag(
    *[
        torch.eye(seq_len, seq_len, dtype=torch.bool)
        if is_pad
        else torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        for seq_len, is_pad in zip(seq_lens, is_padding)
    ]
)
>>> 3.99 ms ± 3.51 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
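A minimal sanity check (not part of the original gist), assuming both cells are meant to produce the same (max_seq_len, max_seq_len) boolean mask:

import torch

seq_lens = [100] * 160
max_seq_len = 16384

# Approach 1: loop over samples, then append the padding identity block
blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens]
remainder = max_seq_len - sum(seq_lens)
if remainder > 0:
    blocks.append(torch.eye(remainder, dtype=torch.bool))
mask_loop = torch.block_diag(*blocks)

# Approach 2: single comprehension over (seq_len, is_padding) pairs
padded_lens = seq_lens + ([remainder] if remainder > 0 else [])
is_padding = [False] * len(seq_lens) + ([True] if remainder > 0 else [])
mask_comp = torch.block_diag(
    *[
        torch.eye(n, dtype=torch.bool)
        if pad
        else torch.tril(torch.ones(n, n, dtype=torch.bool))
        for n, pad in zip(padded_lens, is_padding)
    ]
)

assert mask_loop.shape == (max_seq_len, max_seq_len)
assert torch.equal(mask_loop, mask_comp)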