Skip to content

Instantly share code, notes, and snippets.

@shamanez
Last active April 21, 2021 12:45
Show Gist options
  • Save shamanez/6203995d4c597b740c49b81c72d3bae7 to your computer and use it in GitHub Desktop.
Save shamanez/6203995d4c597b740c49b81c72d3bae7 to your computer and use it in GitHub Desktop.
import logging
import random
import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex
from transformers import (
DPRContextEncoderTokenizerFast)
logger = logging.getLogger(__name__)
class RayRetriever:
def __init__(self):
self.initialized = False
def create_rag_retriever(self, config, question_encoder_tokenizer, ctx_encoder_tokenizer,generator_tokenizer, index):
if not self.initialized:
self.retriever = RagRetriever(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
ctx_encoder_tokenizer=ctx_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.initialized = True
def init_retrieval(self):
self.retriever.index.init_index()
def clear_object():
#this is to set the index with the new one
del self.retriever
self.initialized = False
def retrieve(self, question_hidden_states, n_docs):
#print(self.retriever.index)
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
#newly added
#make sure that index class always uses the updated one
doc_dicts= self.retriever.index.get_doc_dicts(doc_ids)
return doc_ids, retrieved_doc_embeds,doc_dicts
class RagRayDistributedRetriever(RagRetriever):
"""
A distributed retriever built on top of the ``Ray`` API, a library
for building distributed applications (https://docs.ray.io/en/master/).
package. During training, all training workers initialize their own
instance of a `RagRayDistributedRetriever`, and each instance of
this distributed retriever shares a common set of Retrieval Ray
Actors (https://docs.ray.io/en/master/walkthrough.html#remote
-classes-actors) that load the index on separate processes. Ray
handles the communication between the `RagRayDistributedRetriever`
instances and the remote Ray actors. If training is done in a
non-distributed setup, the index will simply be loaded in the same
process as the training worker and Ray will not be used.
Args:
config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer that was used to tokenize the question.
It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
These actor classes run on remote processes and are responsible for performing the index lookup.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
"""
def __init__(self, config, question_encoder_tokenizer,ctx_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
raise ValueError(
"When using Ray for distributed fine-tuning, "
"you'll need to provide the paths instead, "
"as the dataset and the index are loaded "
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
)
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
ctx_encoder_tokenizer=ctx_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.retrieval_workers = retrieval_workers
if len(self.retrieval_workers) > 0:
ray.get(
[
worker.create_rag_retriever.remote(config, question_encoder_tokenizer,ctx_encoder_tokenizer, generator_tokenizer, index)
for worker in self.retrieval_workers
]
)
def init_retrieval(self):
"""
Retriever initialization function, needs to be called from the
training process. This function triggers retrieval initialization
for all retrieval actors if using distributed setting, or loads
index into current process if training is not distributed.
"""
logger.info("initializing retrieval")
if len(self.retrieval_workers) > 0:
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
else:
# Non-distributed training. Load index into this same process.
self.index.init_index()
def retrieve(self, question_hidden_states, n_docs):
"""
Retrieves documents for specified ``question_hidden_states``. If
running training with multiple workers, a random retrieval actor is
selected to perform the index lookup and return the result.
Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
A batch of query vectors to retrieve with.
n_docs (:obj:`int`):
The number of docs retrieved per query.
Output:
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
The retrieval embeddings of the retrieved docs per query.
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
The ids of the documents in the index
doc_dicts (:obj:`List[dict]`):
The retrieved_doc_embeds examples per query.
"""
#print("I am trying to retrive with the workers")
if len(self.retrieval_workers) > 0:
# Select a random retrieval actor.
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
#newly added
doc_ids, retrieved_doc_embeds,doc_dicts = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
else:
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
return retrieved_doc_embeds, doc_ids, doc_dicts#self.index.get_doc_dicts(doc_ids)
@classmethod
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
@classmethod
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
requires_datasets(cls)
requires_faiss(cls)
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder
generator_tokenizer = rag_tokenizer.generator
ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained('facebook/dpr-ctx_encoder-multiset-base')
if indexed_dataset is not None:
config.index_name = "custom"
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
else:
index = cls._build_index(config) #check this build index and lpald index thing
logger.info("done loading the dataset from the passage path")
return cls(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
ctx_encoder_tokenizer = ctx_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
retrieval_workers=actor_handles,
index=index,
)
def set_new_index(self):
logger.info("setting the new index for two GPUs")
#accesed from the training loop
#build the index object again
index =self._build_index(self.config) #the new index being loaded (raw form) #use the class since we are not initalizing the super class
#assign the new index
ray.get([worker.set_index.remote(index) for worker in self.retrieval_workers])
#Ray workers can be bit slow. before the init (load_faisis) workes mightbe accesed to retrive info.
#So it might be better to not to set the index
def re_load(self):
logger.info("re-loading the new dataset with embeddings")
#accesed from the training loop
#build the index object again
index =self._build_index(self.config) #the new index being loaded (raw form) #use the class since we are not initalizing the super class
#assign the new index
ray.get([worker.set_index.remote(index) for worker in self.retrieval_workers])
#Ray workers can be bit slow. before the init (load_faisis) workes mightbe accesed to retrive info.
#So it might be better to not to set the index
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment