# Imports needed by the snippets below (the rest of the Lightning module is omitted in this gist).
import os
import random
import time
from functools import partial
from typing import Dict

import faiss
from datasets import Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
from pynvml import nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
from torch import multiprocessing
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast


# check the training_step hook in pytorch-lightning for further details on this function.
def training_step(self, batch, batch_idx) -> Dict:
    global isEmUpdateBusy  # set while the parallel embedding-computation processes are running
    global isAddIndexBusy  # set while the FAISS index is being rebuilt
    global isOtherThreadIndexBusy
    global processes  # pool of embedding-update worker processes
    global threadHandle_index  # process handle for add_index

    if self.trainer.global_rank == 0:  # we launch the parallel embedding computation only on the master DDP rank
        if batch_idx != 0 and batch_idx % 500 == 0:  # we want our embeddings to get updated every 500th step

            ###### we can assign any number of free GPUs to update the embeddings (optional) ##########
            free_gpu_list = []
            nvmlInit()
            deviceCount = nvmlDeviceGetCount()
            gpu_order_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # your GPU list order
            for i in range(deviceCount):
                handle = nvmlDeviceGetHandleByIndex(i)
                info = nvmlDeviceGetMemoryInfo(handle)
                if info.used / 1e6 < 15:  # treat a GPU with less than ~15 MB in use as free
                    position = gpu_order_list.index(i)
                    free_gpu_list.append("cuda:" + str(position))
            has_free_gpus = len(free_gpu_list) >= 4
            ########### done with checking for free GPUs and selecting the given number of GPUs ########

            if (not isEmUpdateBusy) and has_free_gpus:
                ctx_encoder = self.trainer.model.module.module.model.rag.ctx_encoder
                model_copy = type(ctx_encoder)(self.config_dpr)  # get a new instance; this copy lives on the CPU
                model_copy.load_state_dict(ctx_encoder.state_dict())  # copy the current weights

                ############ using multiple GPUs #################################
                model_copy.share_memory()
                processes = []
                if len(free_gpu_list) > 4:
                    cuda_devices = random.sample(free_gpu_list, 4)  # free_gpu_list[:4]
                else:
                    cuda_devices = free_gpu_list
                num_processes = len(cuda_devices)

                kb_dataset = load_dataset(
                    "csv",
                    data_files=[self.custom_config.csv_path],
                    split="train",
                    delimiter="\t",
                    column_names=["title", "text"],
                    cache_dir=c_dir,  # c_dir: your datasets cache directory
                )
                kb_list = [kb_dataset.shard(num_processes, i, contiguous=True) for i in range(num_processes)]

                for rank in range(num_processes):
                    device = cuda_devices[rank]
                    p = multiprocessing.Process(target=embed_update, args=(model_copy, device, rank, kb_list[rank]))
                    processes.append(p)
                for p in processes:
                    p.start()
                isEmUpdateBusy = True

        if isEmUpdateBusy and (not isAddIndexBusy):
            if all(not p.is_alive() for p in processes):  # all embedding workers have finished
                threadHandle_index = multiprocessing.Process(
                    target=add_index, args=(self.config.passages_path, self.config.index_path)
                )
                threadHandle_index.start()
                isOtherThreadIndexBusy = True
                isAddIndexBusy = True

        if isOtherThreadIndexBusy:
            if not threadHandle_index.is_alive():
                saved_dataset_shards = []
                for address in data_shard_addresses:  # assumed to hold the shard directories written by embed_update
                    saved_dataset_shards.append(load_from_disk(address))
                concat = concatenate_datasets(saved_dataset_shards)
                concat.save_to_disk(self.config.passages_path)
                print("done saving the dataset to the passages_path")

                # re-initialize the RAY retrieval workers with the newly computed embeddings and index
                self.trainer.model.module.module.model.rag.retriever.set_new_index()
                self.trainer.model.module.module.model.rag.retriever.init_retrieval()
                print("done loading the new index")

                isEmUpdateBusy = False
                isOtherThreadIndexBusy = False
                isAddIndexBusy = False

    self.trainer.accelerator_connector.accelerator.barrier("barrier")
    # ... the usual forward pass / loss computation of training_step is omitted in this gist.
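
# --- Not part of the original gist: a minimal sketch of the module-level state the hook above
# --- relies on. The flag names match the globals used in training_step; initializing them once
# --- at import time, and the "spawn" start method, are assumptions rather than code from the gist.
isEmUpdateBusy = False          # an embedding-update round is currently running
isAddIndexBusy = False          # the FAISS index is currently being rebuilt
isOtherThreadIndexBusy = False
processes = []                  # embedding-update worker processes
threadHandle_index = None       # process handle for add_index

# Since the trainer process has already initialized CUDA, workers that call .to(device)
# typically need the "spawn" start method.
multiprocessing.set_start_method("spawn", force=True)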
def embed_update(ctx_encoder, device, process_num, data_shard):
    arrow_folder = "data_" + str(process_num)
    arrow_file_name = "data_" + str(process_num) + ".arrow"
    passages_path = "ur_dir" + arrow_folder + "/"  # "ur_dir": replace with your own directory
    # e.g. /home/gsir059/cache/
    cache_file_name = "ur_dir" + arrow_file_name

    ctx_encoder = ctx_encoder.to(device=device)
    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

    def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast, device) -> dict:
        """Compute the DPR embeddings of document passages."""
        input_ids = ctx_tokenizer(
            documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
        )["input_ids"]
        embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
        return {"embeddings": embeddings.detach().cpu().numpy()}

    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )  # optional: save as float32 instead of float64 to save space

    dataset = data_shard.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device),
        batched=True,
        batch_size=16,
        features=new_features,
        cache_file_name=cache_file_name,
        load_from_cache_file=False,
    )

    dataset.save_to_disk(passages_path)
    os.remove(dataset.cache_files[0]["filename"])  # remove the temporary map cache once the shard is saved
def add_index(passage_path, index_path):
    # concatenate the embedded shards written by embed_update; data_shards_path is assumed
    # to hold the list of shard directories saved above
    data_shard_list = []
    for shard_address in data_shards_path:
        data_shard_list.append(load_from_disk(shard_address))
    concat = concatenate_datasets(data_shard_list)

    faiss.omp_set_num_threads(70)
    start_time = time.time()
    index = faiss.index_factory(768, "IVF4096,Flat")
    index.nprobe = 128
    concat.add_faiss_index("embeddings", custom_index=index, train_size=-1)  # cannot queue this since it uses a C++ object
    concat.get_index("embeddings").save(index_path)  # persist the trained FAISS index for the retriever
    print("--- %s sec ---" % (time.time() - start_time))