Web QA DPR using Haystack
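This script builds a question-answering pipeline over the document-level WikiText-103 dump using Haystack: articles are indexed into a FAISS document store, embedded with a Dense Passage Retriever (DPR), and then queried two ways, with an extractive pipeline (RoBERTa reader) and a generative one (RAG).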
from datasets import load_dataset
from tqdm import tqdm
from haystack.document_stores import FAISSDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes.retriever.dense import DensePassageRetriever
from haystack.nodes.reader import FARMReader
from haystack.pipelines import ExtractiveQAPipeline, GenerativeQAPipeline
from haystack.utils import print_answers
from haystack.nodes import RAGenerator
# Load the document-level WikiText-103 corpus
dataset = load_dataset("EleutherAI/wikitext_document_level", "wikitext-103-v1")

# Merge all partitions into a single {title: text} dictionary.
# Each page starts with "= <title> =", so splitting on '=' with
# maxsplit=2 separates the title from the body text.
merged_data = {}
for partition in ['train', 'test', 'validation']:
    if partition in dataset:
        for i in range(len(dataset[partition])):
            page = dataset[partition][i]['page']
            title = page.split('=', 2)[1].strip()
            text = page.split('=', 2)[2].strip()
            merged_data[title] = text

article_names = list(merged_data.keys())
print(article_names[0])
print(merged_data[article_names[0]])
# Path to the SQLite database that backs the FAISS document store
db_path = "sqlite:///document_store.db"

# Toggle between a persistent FAISS store and a simple in-memory store
use_faiss = True
if use_faiss:
    document_store = FAISSDocumentStore(
        faiss_index_factory_str="Flat",
        return_embedding=True,
        embedding_dim=768,  # must match the DPR encoders' output size
        sql_url=db_path,
        index="title",
        progress_bar=False,
    )
else:
    document_store = InMemoryDocumentStore()
# Remove any documents left over from a previous run
document_store.delete_documents()

# Write each article to the document store, keeping its title as metadata
#for k in tqdm(list(merged_data.keys())[0:100]):  # uncomment to index a small sample
for k in tqdm(list(merged_data.keys())):
    document_store.write_documents([{
        "content": merged_data[k],
        "meta": {
            "title": k,
        }
    }])
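# Optional sanity check (not in the original gist): confirm how many
# documents were indexed before computing embeddings.
#print(document_store.get_document_count())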
# DPR retriever: separate BERT encoders for questions and passages
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=True,
    #embed_title=True,
)

# Compute and store passage embeddings for every indexed document
document_store.update_embeddings(retriever=retriever)
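# Optional (assumes the Haystack 1.x API): when re-running the script against
# an existing store, skip documents that already have embeddings.
#document_store.update_embeddings(retriever=retriever, update_existing_embeddings=False)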
# Initialize the RAG generator, which conditions answer generation
# on the passages returned by the retriever
generator = RAGenerator(
    model_name_or_path="facebook/rag-token-nq",
    use_gpu=True,
    max_length=50,
    min_length=20,
    num_beams=5,
    retriever=retriever,
)
query = 'Who is Tina Fey?'

# Extractive QA: DPR retrieval + RoBERTa span-extraction reader
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
pipe = ExtractiveQAPipeline(reader, retriever)
prediction = pipe.run(
    query=query, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
)
print_answers(prediction, details="minimum")
# Generative QA: DPR retrieval + RAG answer generation
pipe = GenerativeQAPipeline(generator=generator, retriever=retriever)
res = pipe.run(query=query, params={"Generator": {"top_k": 5}, "Retriever": {"top_k": 5}})
print_answers(res, details="minimum")

# Trace each generated answer back to the titles of its source documents
for a in res['answers']:
    docs = [document_store.get_document_by_id(d) for d in a.document_ids]
    titles = [d.meta['title'] for d in docs]
    print(a.score, titles, a.answer)
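# Optional (not in the original gist): persist the FAISS index so the
# embeddings can be reloaded later instead of being recomputed.
#document_store.save(index_path="faiss_index.faiss")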