import torch  # run once, outside the timed cells

%%timeit
seq_lens = [100] * 160
max_seq_len = 16384
num_samples_in_pack = len(seq_lens)
block_attn_masks = []
total_seq_len = 0
for i, seq_len in enumerate(seq_lens):
    total_seq_len += seq_len
    # Append a lower-triangular matrix for the causal mask
    block_attn_masks.append(
        torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    )
    # If we're at the last sample and the total seq len is less than the
    # max seq len, pad the remainder with an identity matrix
    if i == num_samples_in_pack - 1 and total_seq_len < max_seq_len:
        block_attn_masks.append(
            torch.eye(
                max_seq_len - total_seq_len,
                max_seq_len - total_seq_len,
                dtype=torch.bool,
            )
        )
mask = torch.block_diag(*block_attn_masks)

>>> 5.9 ms ± 175 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
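# For intuition, the same construction at toy scale (hand-picked sizes, not
# from the timed cells): two packed sequences of lengths 2 and 3 with
# max_seq_len = 6 give one causal block per sample plus a 1x1 identity block
# for the single padding position.
blocks = [
    torch.tril(torch.ones(2, 2, dtype=torch.bool)),  # causal block, sample 0
    torch.tril(torch.ones(3, 3, dtype=torch.bool)),  # causal block, sample 1
    torch.eye(1, dtype=torch.bool),                  # identity block for padding
]
print(torch.block_diag(*blocks).int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])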
%%timeit
seq_lens = [100] * 160
max_seq_len = 16384
num_samples_in_pack = len(seq_lens)
padding = max_seq_len - sum(seq_lens)
is_padding = [False] * len(seq_lens)
if padding > 0:
    seq_lens = seq_lens + [padding]
    is_padding = is_padding + [True]
# Build all blocks in one comprehension: identity block for the padding
# slot, lower-triangular causal block for each real sample
mask = torch.block_diag(
    *[
        torch.eye(seq_len, seq_len, dtype=torch.bool)
        if pad
        else torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        for seq_len, pad in zip(seq_lens, is_padding)
    ]
)

>>> 3.99 ms ± 3.51 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
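# Usage sketch (my own wrapper, not part of the gist): wrap the faster
# construction in a helper and feed the boolean mask to
# F.scaled_dot_product_attention, where True means "may attend". The identity
# blocks keep every padding row attending to itself, so no row of the score
# matrix is fully masked (a fully masked row would softmax to NaN).
import torch
import torch.nn.functional as F

def packed_causal_mask(seq_lens, max_seq_len):
    # Hypothetical helper name; the body mirrors the second cell above.
    padding = max_seq_len - sum(seq_lens)
    is_padding = [False] * len(seq_lens)
    if padding > 0:
        seq_lens = seq_lens + [padding]
        is_padding = is_padding + [True]
    return torch.block_diag(
        *[
            torch.eye(n, dtype=torch.bool)
            if pad
            else torch.tril(torch.ones(n, n, dtype=torch.bool))
            for n, pad in zip(seq_lens, is_padding)
        ]
    )

mask = packed_causal_mask([3, 4], max_seq_len=8)  # tiny pack for the demo
q = k = v = torch.randn(1, 1, 8, 16)              # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # 2D mask broadcasts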