@shamanez
Last active April 21, 2021 13:09
# check the training_step hook in pytorch-lightning for further details on this function.
# imports this snippet needs (the surrounding training script provides the rest of the setup):
import multiprocessing
import os
import random
import time
from functools import partial
from typing import Dict

import faiss
from datasets import Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
from pynvml import nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast


def training_step(self, batch, batch_idx) -> Dict:
    # global flags/handles shared with the parallel embedding-computation and indexing processes
    global isEmUpdateBusy, isAddIndexBusy, isOtherThreadIndexBusy, processes, threadHandle_index

    if self.trainer.global_rank == 0:  # we launch the parallel embedding computation only on the master DDP process
        if (not batch_idx == 0) and (batch_idx % 500 == 0):  # we want our embeddings to get updated on every 500th step
            ###### we can assign any number of free GPUs to update the embeddings (optional) ##########
            free_gpu_list = []
            nvmlInit()
            deviceCount = nvmlDeviceGetCount()
            gpu_order_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # your GPU list order
            for i in range(deviceCount):
                handle = nvmlDeviceGetHandleByIndex(i)
                info = nvmlDeviceGetMemoryInfo(handle)
                if info.used / 1e6 < 15:  # treat a GPU with less than ~15 MB of used memory as free
                    position = gpu_order_list.index(i)
                    free_gpu_list.append("cuda:" + str(position))
            if len(free_gpu_list) >= 4:
                has_free_gpus = True
            else:
                has_free_gpus = False
            ########### done with checking for free GPUs and selecting the given number of GPUs ########

            if (not isEmUpdateBusy) and has_free_gpus:
                ctx_encoder = self.trainer.model.module.module.model.rag.ctx_encoder
                model_copy = type(ctx_encoder)(self.config_dpr)  # get a new instance; this copy lives on the CPU
                model_copy.load_state_dict(ctx_encoder.state_dict())  # copy the current weights

                ############ using multiple GPUs #################################
                model_copy.share_memory()
                processes = []
                if len(free_gpu_list) > 4:
                    cuda_devices = random.sample(free_gpu_list, 4)  # or simply free_gpu_list[:4]
                else:
                    cuda_devices = free_gpu_list
                num_processes = len(cuda_devices)
                kb_dataset = load_dataset(
                    "csv",
                    data_files=[self.custom_config.csv_path],
                    split="train",
                    delimiter="\t",
                    column_names=["title", "text"],
                    cache_dir=c_dir,  # cache directory for the knowledge-base CSV (defined in the surrounding script)
                )
                kb_list = [kb_dataset.shard(num_processes, i, contiguous=True) for i in range(num_processes)]
                for rank in range(num_processes):
                    device = cuda_devices[rank]
                    p = multiprocessing.Process(target=embed_update, args=(model_copy, device, rank, kb_list[rank]))
                    processes.append(p)
                for p in processes:
                    p.start()
                isEmUpdateBusy = True

        if isEmUpdateBusy and (not isAddIndexBusy):
            # start building the FAISS index once every embedding-computation process has finished
            if all(not p.is_alive() for p in processes):
                threadHandle_index = multiprocessing.Process(
                    target=add_index, args=(self.config.passages_path, self.config.index_path)
                )
                threadHandle_index.start()
                isOtherThreadIndexBusy = True
                isAddIndexBusy = True

        if isOtherThreadIndexBusy:
            if not threadHandle_index.is_alive():  # the indexing process has finished
                saved_dataset_shards = []
                for address in data_shard_addresses:  # paths where each worker saved its embedded dataset shard
                    saved_dataset_shards.append(load_from_disk(address))
                concat = concatenate_datasets(saved_dataset_shards)
                concat.save_to_disk(self.config.passages_path)
                print("done saving the dataset to the passages_path")
                # re-initializing the RAY workers with the newly computed embeddings and index
                self.trainer.model.module.module.model.rag.retriever.set_new_index()
                self.trainer.model.module.module.model.rag.retriever.init_retrieval()
                print("done loading the new index")
                isEmUpdateBusy = False
                isOtherThreadIndexBusy = False
                isAddIndexBusy = False

    self.trainer.accelerator_connector.accelerator.barrier("barrier")
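

# --- illustrative sketch (not part of the original gist) --------------------------------------
# training_step treats the busy flags and process handles as module-level globals; a minimal
# initialization the surrounding training script is assumed to perform before training starts:
isEmUpdateBusy = False         # an embedding-update round is currently in flight
isAddIndexBusy = False         # the FAISS indexing process has been launched
isOtherThreadIndexBusy = False
processes = []                 # per-GPU embedding worker processes
threadHandle_index = None      # handle of the indexing process
# -----------------------------------------------------------------------------------------------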
def embed_update(ctx_encoder, device, process_num, data_shard):
    arrow_folder = "data_" + str(process_num)
    arrow_file_name = "data_" + str(process_num) + ".arrow"
    # 'ur_dir' is a placeholder for your cache/output directory (e.g. /home/gsir059/cache/)
    passages_path = "ur_dir" + arrow_folder + "/"
    cache_file_name = "ur_dir" + arrow_file_name

    ctx_encoder = ctx_encoder.to(device=device)
    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

    def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast, device) -> dict:
        """Compute the DPR embeddings of document passages."""
        input_ids = ctx_tokenizer(
            documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
        )["input_ids"]
        embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
        return {"embeddings": embeddings.detach().cpu().numpy()}

    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )  # optional: save as float32 instead of float64 to save space

    dataset = data_shard.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device),
        batched=True,
        batch_size=16,
        features=new_features,
        cache_file_name=cache_file_name,
        load_from_cache_file=False,
    )
    dataset.save_to_disk(passages_path)
    os.remove(dataset.cache_files[0]["filename"])  # clean up the temporary Arrow cache file created by .map
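

# --- illustrative sketch (not part of the original gist) --------------------------------------
# embed_update saves each shard to 'ur_dir' + 'data_<process_num>/', so the shard-path lists
# referenced above (data_shard_addresses) and below (data_shards_path) are assumed to follow the
# same pattern; get_shard_paths is a hypothetical helper showing one way they could be built:
def get_shard_paths(base_dir="ur_dir", num_processes=4):
    return [base_dir + "data_" + str(rank) + "/" for rank in range(num_processes)]


data_shard_addresses = data_shards_path = get_shard_paths()
# -----------------------------------------------------------------------------------------------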
def add_index(passage_path, index_path):
    # data_shards_path is assumed to be the list of per-process shard directories saved by embed_update
    data_shard_list = []
    for shard_path in data_shards_path:
        data_shard_list.append(load_from_disk(shard_path))
    concat = concatenate_datasets(data_shard_list)
    faiss.omp_set_num_threads(70)
    start_time = time.time()
    index = faiss.index_factory(768, "IVF4096,Flat")
    index.nprobe = 128
    concat.add_faiss_index("embeddings", custom_index=index, train_size=-1)  # cannot queue this since it uses C++ under the hood
    concat.get_index("embeddings").save(index_path)  # serialize the trained FAISS index to the given path
    print("--- %s sec ---" % (time.time() - start_time))