@shamanez
Last active April 16, 2021 04:40
from typing import Dict
import multiprocessing


def training_step(self, batch, batch_idx) -> Dict:
    global stepCount
    global isEmUpdateBusy
    global isAddIndexBusy
    global processes
    global isOtherThreadIndexBusy

    if self.trainer.global_rank == 0:  # run the embedding-update logic only on the master DDP rank
        if (not batch_idx == 0) and (batch_idx % 500 == 0):  # refresh the passage embeddings every 500th step
            # we can assign any number of free GPUs to update the embeddings (optional)
            # ----- embedding-computation code (elided; it spawns the `processes` workers and sets isEmUpdateBusy) -----
            pass

        if isEmUpdateBusy and (not isAddIndexBusy):
            # start re-indexing only after every embedding-update process has finished
            if not any(process.is_alive() for process in processes):
                threadHandle_index = multiprocessing.Process(
                    target=add_index,
                    args=(self.config.passages_path, self.config.index_path),
                )
                threadHandle_index.start()
                isOtherThreadIndexBusy = True
                isAddIndexBusy = True
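
The placeholder comments above elide the embedding-computation launch. Below is a minimal, hypothetical sketch of how such a launch could look; `embed_shard`, `shard_dirs`, and `free_gpu_list` are illustrative names that do not appear in the gist, and this is not the author's elided code.

# Hypothetical sketch (assumption, not the gist's elided code): spawn one worker
# per free GPU to re-encode a passage shard with the current context encoder.
def launch_embedding_update(encoder_state, shard_dirs, free_gpu_list):
    procs = []
    for gpu_id, shard_dir in zip(free_gpu_list, shard_dirs):
        p = multiprocessing.Process(
            target=embed_shard,  # hypothetical worker: re-encodes one shard on `gpu_id` and saves it to disk
            args=(encoder_state, shard_dir, gpu_id),
        )
        p.start()
        procs.append(p)
    return procs  # training_step polls is_alive() on these before kicking off add_index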
import glob
import os
import time

import faiss
from datasets import concatenate_datasets, load_from_disk


def add_index(passage_path, index_path):
    data_shard_list = []
    # load every re-encoded passage shard saved under `passage_path` (assumes one saved-to-disk dataset directory per shard)
    for shard_path in sorted(glob.glob(os.path.join(passage_path, "*"))):
        data_shard_list.append(load_from_disk(shard_path))
    concat = concatenate_datasets(data_shard_list)
    faiss.omp_set_num_threads(70)
    start_time = time.time()
    index = faiss.index_factory(768, "IVF4096,Flat")
    index.nprobe = 128
    concat.add_faiss_index("embeddings", custom_index=index, train_size=-1)  # cannot queue this elsewhere since it calls into C++ (faiss)
    concat.get_index("embeddings").save(index_path)
    print("--- %s sec ---" % (time.time() - start_time))