# @yashbonde
#
# In this quick script we try to solve the sharding problem:
# for very large datasets there is often no way to tokenize everything and
# store it up front. For CLM datasets we have a fixed dataset where each row
# has a dynamic number of tokens. A dummy dataset looks as follows:
#
# j    n    sequence (shown without the EOT token; EOT = 420)
# [0] [15] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
# [1] [13] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
# [2] [11] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [3] [13] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
# [4] [15] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
# [5] [ 8] [0, 1, 2, 3, 4, 5, 6],
# [6] [14] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
# [7] [ 8] [0, 1, 2, 3, 4, 5, 6],
# [8] [11] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [9] [10] [0, 1, 2, 3, 4, 5, 6, 7, 8]
#
# During initialisation we provide
# a) seqlen: size of each output sequence
# b) stride: offset between the start positions of two consecutive (overlapping) samples
#
# When training the model we train on contiguous spans (size = seqlen);
# each span is obtained either by merging multiple sequences or by slicing
# a single sequence.
#
# - for seqlen = 10 and stride = 10
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 420, 0, 1, 2, 3, 4], ...
# stride = seqlen ensures there is no overlap between the sequences
#
# - for seqlen = 10 and stride = 5
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 420, 0, 1, 2, 3, 4], ...
# [5, 6, 7, 8, 9, 10, 11, 12, 13, 420], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], ...
# notice how the sequences now overlap
#
# TASK: given a list of lists (see above) called `ds`, plus `seqlen` and `stride`,
# a) can you find the total number of samples (l) in the dataset?
# b) given any i < l, can you return the correct sequence?
#
# The code is given below; however, there is a bug and it is making me feel (ಥ﹏ಥ)
# can you solve it? (A brute-force sketch of the intended span construction follows this header.)
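# A minimal brute-force sketch of the span construction described above, to make
# the seqlen/stride idea concrete. `demo_spans`, `demo_ds` and `eot` are
# illustrative names introduced here and are not part of the original script.
def demo_spans(demo_ds, seqlen, stride, eot):
    # flatten the dataset into a single token stream, appending EOT after every row
    flat = [t for row in demo_ds for t in row + [eot]]
    spans = []
    # each start offset (0, stride, 2*stride, ... < seqlen) is one pass over the
    # stream; within a pass consecutive spans are seqlen apart, so overlap only
    # occurs between passes with different offsets
    for offset in range(0, seqlen, stride):
        for start in range(offset, len(flat) - seqlen + 1, seqlen):
            spans.append(flat[start:start + seqlen])
    return spans

# e.g. demo_spans([[0, 1, 2, 3], [0, 1, 2]], seqlen=4, stride=2, eot=420)
# -> [[0, 1, 2, 3], [420, 0, 1, 2], [2, 3, 420, 0]]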
import numpy as np
# create a dataset mimicking the sentences dataset
ds = [np.arange(x).tolist() for x in np.random.randint(low = 5, high = 15, size = (10))]
lengths = np.array([len(x) + 1 for x in ds]) # plus 1 for <|endoftext|> token
length_cumsum = np.cumsum(lengths)
print(ds, lengths, length_cumsum)
stride = 4
seqlen = 8
total_tokens = sum(lengths)
# A) Obtain the total number of samples
_len = 0
_start = 0
start_wise_count = [0]
while _start < seqlen:
    # with each pass we shift the start by `stride`, so the number of usable
    # tokens shrinks by that much (tracked via `_start`); the shifting continues
    # while `_start < seqlen`, because beyond that we would land inside the
    # second window of the previous pass and only repeat existing samples.
    # One extra window is subtracted whenever the tokens do not divide evenly.
    tokens_flat = total_tokens - _start
    this_batch = tokens_flat // seqlen - int(tokens_flat % seqlen != 0)
    _len += this_batch
    _start += stride
    start_wise_count.append(this_batch)
start_wise_count = np.cumsum(start_wise_count)
print(start_wise_count)
print("Total Samples:", _len)
# B) given `i` find the correct sequence
# in our representation indices are bucketed by start offset, i.e. if
# start_wise_count = [0, 13, 24] then the first 13 samples belong to _start = 0
# and the next 11 belong to _start = stride
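# A tiny worked example of the bucket lookup used inside the loop below
# (illustrative values only, not part of the original script): with cumulative
# counts [0, 13, 24] and i = 15, `(counts > i).argmax() - 1` gives bucket 1
# (i.e. _start = stride) and `i - counts[1]` gives position 2 inside that bucket.
_demo_counts = np.cumsum([0, 13, 11])                 # -> array([ 0, 13, 24])
_demo_i = 15
_demo_bucket = (_demo_counts > _demo_i).argmax() - 1  # -> 1
_demo_pos = _demo_i - _demo_counts[_demo_bucket]      # -> 2
print("demo bucket:", _demo_bucket, "| position in bucket:", _demo_pos)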
EOT = 420
for i in range(start_wise_count[-1]):
    stride_mult = (start_wise_count > i).argmax() - 1  # which stride bucket this index falls in
    _start = stride * stride_mult  # shift in start token, added to the flat token index
    stride_bucket_start_idx = i - start_wise_count[stride_mult]  # position of this sample inside its stride bucket
    token_start = stride_bucket_start_idx * seqlen + _start
    input_ids = []
    while len(input_ids) < seqlen:
        # pick the required sequence from the dataset and append the EOT token;
        # in the real pipeline this ds[seq_idx] would come from tokenizer.encode(text)
        seq_idx = (length_cumsum > token_start).argmax()
        seq = ds[seq_idx] + [EOT]
        # in general the starting point lies somewhere in the middle of the
        # sequence, so we locate it as seq_start
        seq_start = len(seq) - (length_cumsum[seq_idx] - token_start)
        # print("len_cumsum:", length_cumsum[seq_idx], "| token_start:", token_start, "| seq_idx:", seq_idx, "| seq:", seq, "| len:", len(input_ids), "| seq_start:", seq_start)
        seq_to_add = seq[seq_start:]
        input_ids.extend(seq_to_add)
        token_start += len(seq_to_add)
    input_ids = input_ids[:seqlen]
    print(i, "\t", input_ids)
# For the data given at the top and the code given above, we obtain the following
# sequences. However, note that this is incorrect: after i = 12, sample i = 13
# should have been [7, 8, 9, 420, 0, 1, 2, 3, 4, 5]
#
# i sequence
# 0 [0, 1, 2, 3, 4, 5, 6, 7]
# 1 [8, 9, 10, 11, 12, 13, 420, 0]
# 2 [1, 2, 3, 4, 5, 6, 7, 8]
# 3 [9, 10, 11, 420, 0, 1, 2, 3]
# 4 [4, 5, 6, 7, 8, 9, 420, 0]
# 5 [1, 2, 3, 4, 5, 6, 7, 8]
# 6 [9, 10, 11, 420, 0, 1, 2, 3]
# 7 [4, 5, 6, 7, 8, 9, 10, 11]
# 8 [12, 13, 420, 0, 1, 2, 3, 4]
# 9 [5, 6, 420, 0, 1, 2, 3, 4]
# 10 [5, 6, 7, 8, 9, 10, 11, 12]
# 11 [420, 0, 1, 2, 3, 4, 5, 6]
# 12 [420, 0, 1, 2, 3, 4, 5, 6]
# 13 [4, 5, 6, 7, 8, 9, 10, 11]
# 14 [12, 13, 420, 0, 1, 2, 3, 4]
# 15 [5, 6, 7, 8, 9, 10, 11, 420]
# 16 [0, 1, 2, 3, 4, 5, 6, 7]
# 17 [8, 9, 420, 0, 1, 2, 3, 4]
# 18 [5, 6, 7, 8, 9, 10, 11, 420]
# 19 [0, 1, 2, 3, 4, 5, 6, 7]
# 20 [8, 9, 10, 11, 12, 13, 420, 0]
# 21 [1, 2, 3, 4, 5, 6, 420, 0]
# 22 [1, 2, 3, 4, 5, 6, 7, 8]
# 23 [9, 10, 11, 12, 420, 0, 1, 2]
# 24 [3, 4, 5, 6, 420, 0, 1, 2]
# 25 [3, 4, 5, 6, 7, 8, 9, 420]
#