Skip to content

Instantly share code, notes, and snippets.

@harsh-99
Created May 7, 2021 17:32
Show Gist options
  • Save harsh-99/423bf7477ac3774ae74224aaf4aa4d7a to your computer and use it in GitHub Desktop.
Save harsh-99/423bf7477ac3774ae74224aaf4aa4d7a to your computer and use it in GitHub Desktop.
def create_bin(text, bin_size):
    """Group sample indices into buckets of similar length.

    The range ``[min(text), max(text)]`` is cut into consecutive buckets that
    are ``bin_size`` wide. Each bucket is keyed by its inclusive upper bound
    (``min_len + k * bin_size - 1``); the last bucket is clamped to, and keyed
    by, the maximum length itself.

    Args:
        text: Sequence of sample lengths; position ``i`` holds the length of
            sample ``i``.
        bin_size: Width of every bucket. Must be positive.

    Returns:
        dict mapping each bucket's upper bound to the list of sample indices
        whose length falls in that bucket, in input order. Buckets that
        receive no sample are kept with an empty list.

    Raises:
        ValueError: If ``text`` is empty or ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        # A non-positive width would never advance the bucket bound below,
        # so the bucket-creation loop would not terminate.
        raise ValueError("bin_size must be a positive number, got %r" % (bin_size,))
    if len(text) == 0:
        raise ValueError("text must contain at least one length")

    max_len = max(text)
    min_len = min(text)

    # Create every bucket up front so that empty buckets are still present.
    buckets = {}
    upper = min_len + bin_size - 1
    while upper < max_len:
        buckets[upper] = []
        upper += bin_size
    buckets[max_len] = []

    for index, length in enumerate(text):
        # Upper bound of the bucket containing ``length``; the final bucket
        # may be narrower than bin_size, hence the clamp to max_len.
        bound = ((length - min_len) // bin_size + 1) * bin_size + min_len - 1
        buckets[min(bound, max_len)].append(index)
    return buckets
class Sampler(torch.utils.data.Sampler):
    """Batch sampler that groups samples of similar length under a token budget.

    Samples are bucketed by token count (see ``create_bin``). Every iteration
    yields lists of sample indices (batches), built from the longest bucket
    down to the shortest so that samples sharing a batch have similar lengths.
    """

    def __init__(self, n_tokens, data, bin_size):
        """Tokenize ``data`` and pre-compute the length buckets.

        Args:
            n_tokens: Token budget for a single batch.
            data: Iterable of text lines; each line is tokenized with
                ``gensim.utils.simple_preprocess`` and its token count is
                used as the sample length.
            bin_size: Width of the length buckets passed to ``create_bin``.
        """
        self.n_tokens = n_tokens
        self.bin_size = bin_size
        self.text_len = [len(gensim.utils.simple_preprocess(line)) for line in data]
        self.bins_normal = create_bin(self.text_len, self.bin_size)

    def __iter__(self):
        """Return an iterator over freshly shuffled batches of sample indices.

        A batch's cost is its padded size: the upper bound of the longest
        bucket in the batch times the number of samples in it. A batch is
        closed as soon as adding one more sample would push that cost above
        ``n_tokens``. A sample whose own bucket bound already exceeds
        ``n_tokens`` is emitted alone in its own batch rather than dropped.
        No empty batch is ever produced.
        """
        batches = []
        current = []
        # Bucket bound of the first (and, because buckets are visited from
        # longest to shortest, the longest) sample in the current batch.
        padded_len = 0
        for bound in sorted(self.bins_normal, reverse=True):
            # Shuffle a copy so the cached buckets stay untouched between epochs.
            members = list(self.bins_normal[bound])
            random.shuffle(members)
            for index in members:
                # Only close a batch that already holds something; otherwise an
                # over-budget sample would leave an empty batch behind.
                if current and (len(current) + 1) * padded_len > self.n_tokens:
                    batches.append(current)
                    current = []
                if not current:
                    padded_len = bound
                current.append(index)
        if current:
            batches.append(current)
        # Randomize batch order so training does not always see long batches first.
        random.shuffle(batches)
        return iter(batches)
@kaikai23
Copy link

This is helpful!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment