Created
May 6, 2021 21:56
-
-
Save yashbonde/c22a1fd5026e27f587801766d70ae6ef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# @yashbonde
#
# In this quick script we are trying to solve the sharding problem:
# often in very large datasets there is no way to tokenize everything and
# store it. Considering CLM datasets we have a fixed dataset where each row
# has a dynamic number of tokens. A dummy looks like the following:
#
# j     n    sequence (shown w/o the EOT token = 42)
# [0]  [15]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
# [1]  [13]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
# [2]  [11]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [3]  [13]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
# [4]  [15]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
# [5]  [ 8]  [0, 1, 2, 3, 4, 5, 6],
# [6]  [14]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
# [7]  [ 8]  [0, 1, 2, 3, 4, 5, 6],
# [8]  [11]  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [9]  [10]  [0, 1, 2, 3, 4, 5, 6, 7, 8]
#
# During initialisation we provide:
# a) seqlen: size of each output sequence
# b) stride: difference between two consecutive samples
#
# When training the model we train on continuous spans (size = seqlen),
# and these spans are obtained by merging multiple sequences or taken from
# the same sequence itself.
#
# - for seqlen = 10 and stride = 10:
#   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 42, 0, 1, 2, 3, 4], ...
#   stride = seqlen ensures there is no overlap between the sequences
#
# - for seqlen = 10 and stride = 5:
#   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 42, 0, 1, 2, 3, 4], ...
#   [5, 6, 7, 8, 9, 10, 11, 12, 13, 42], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], ...
#   notice how the sequences overlap
#
# TASK: given a list of lists (see above) called `ds`, plus `seqlen` and `stride`:
# a) can you find the total number of samples (l) in the dataset?
# b) given any i < l, can you get me the correct sequence?
#
# Code is given to you, however there is a bug and it is making me feel (ಥ﹏ಥ)
# can you solve this?
import numpy as np | |
# create a dataset mimicking the sentences dataset | |
ds = [np.arange(x).tolist() for x in np.random.randint(low = 5, high = 15, size = (10))] | |
lengths = np.array([len(x) + 1 for x in ds]) # plus 1 for <|endoftext|> token | |
length_cumsum = np.cumsum(lengths) | |
prin(ds, lengths, length_cumsum) | |
stride = 4 | |
seqlen = 8 | |
total_tokens = sum(lengths) | |
# A) Obtain the total number of samples | |
_len = 0 | |
_start = 0 | |
start_wise_count = [0] | |
while _start < seqlen: | |
# with each iteration shift the start by stride and so total tokens | |
# reduce by stride(s) multiples (handled as _start) and this shift | |
# continues while the total moves taken is < seqlen, because then | |
# you'll enter into the second patch and can overlap | |
# we also subtract incase the there | |
tokens_flat = total_tokens - _start | |
this_batch = tokens_flat // seqlen - int(tokens_flat % seqlen != 0) | |
_len += this_batch | |
_start += stride | |
start_wise_count.append(this_batch) | |
start_wise_count = np.cumsum(start_wise_count) | |
print(start_wise_count) | |
print("Total Samples:", _len) | |
# B) given `i` find the correct sequence | |
# in our representation indices are stride based, ie. if start_wise_counts = [13, 24] | |
# meaning first 13 samples belong to _start = 0 and next 11 belong to _start = stride | |
EOT = 420 | |
for i in range(start_wise_count[-1]): | |
stride_mult = (start_wise_count > i).argmax() - 1 # this is stride multiplier | |
_start = stride * stride_mult # shift in start token to be added to token_id | |
stride_bucket_start_idx = i - start_wise_count[stride_mult] # this is the starting idx in that particular stride | |
token_start = stride_bucket_start_idx * seqlen + _start | |
input_ids = [] | |
while len(input_ids) < seqlen: | |
# pick the required sequence from dataset, note add the EOT token | |
# this ds[seq_idx] is replaced by tokenizer.encode(text) | |
seq_idx = (length_cumsum > token_start).argmax() | |
seq = ds[seq_idx] + [EOT] | |
# but in general the starting point in the sequence is somewhere in the middle | |
# so we find that in seq_start | |
seq_start = len(seq) - (length_cumsum[seq_idx] - token_start) | |
# print("len_cumsum:", length_cumsum[seq_idx] , "| token_start:", token_start, "| seq_idx:", seq_idx, "| seq:", seq, "| len:", len(input_ids), "| seq_start:", seq_start) | |
seq_to_add = seq[seq_start:] | |
input_ids.extend(seq_to_add) | |
token_start += len(seq_to_add) | |
input_ids = input_ids[:seqlen] | |
print(i, "\t", input_ids) | |
# For the data given at top and code given below, we obtain the following sequences. | |
# However note that this is incorrect because after i = 12 there should have been | |
# 13 = [7, 8, 9, 420, 0, 1, 2, 3, 4, 5] | |
# | |
# i sequence | |
# 0 [0, 1, 2, 3, 4, 5, 6, 7] | |
# 1 [8, 9, 10, 11, 12, 13, 420, 0] | |
# 2 [1, 2, 3, 4, 5, 6, 7, 8] | |
# 3 [9, 10, 11, 420, 0, 1, 2, 3] | |
# 4 [4, 5, 6, 7, 8, 9, 420, 0] | |
# 5 [1, 2, 3, 4, 5, 6, 7, 8] | |
# 6 [9, 10, 11, 420, 0, 1, 2, 3] | |
# 7 [4, 5, 6, 7, 8, 9, 10, 11] | |
# 8 [12, 13, 420, 0, 1, 2, 3, 4] | |
# 9 [5, 6, 420, 0, 1, 2, 3, 4] | |
# 10 [5, 6, 7, 8, 9, 10, 11, 12] | |
# 11 [420, 0, 1, 2, 3, 4, 5, 6] | |
# 12 [420, 0, 1, 2, 3, 4, 5, 6] | |
# 13 [4, 5, 6, 7, 8, 9, 10, 11] | |
# 14 [12, 13, 420, 0, 1, 2, 3, 4] | |
# 15 [5, 6, 7, 8, 9, 10, 11, 420] | |
# 16 [0, 1, 2, 3, 4, 5, 6, 7] | |
# 17 [8, 9, 420, 0, 1, 2, 3, 4] | |
# 18 [5, 6, 7, 8, 9, 10, 11, 420] | |
# 19 [0, 1, 2, 3, 4, 5, 6, 7] | |
# 20 [8, 9, 10, 11, 12, 13, 420, 0] | |
# 21 [1, 2, 3, 4, 5, 6, 420, 0] | |
# 22 [1, 2, 3, 4, 5, 6, 7, 8] | |
# 23 [9, 10, 11, 12, 420, 0, 1, 2] | |
# 24 [3, 4, 5, 6, 420, 0, 1, 2] | |
# 25 [3, 4, 5, 6, 7, 8, 9, 420] | |
# |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment