# Web QA DPR using Haystack
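# Pipeline: load wikitext-103 at the document level, index the articles in a
# FAISS document store, embed them with DPR, then answer a query both
# extractively (FARMReader) and generatively (RAG).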
from datasets import load_dataset
from tqdm import tqdm

from haystack.document_stores import FAISSDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes import RAGenerator
from haystack.nodes.reader import FARMReader
from haystack.nodes.retriever.dense import DensePassageRetriever
from haystack.pipelines import ExtractiveQAPipeline, GenerativeQAPipeline
from haystack.utils import print_answers
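# Each row's 'page' field holds one full article formatted as "= Title ="
# followed by the body, so splitting on '=' with maxsplit=2 separates the
# title ([1]) from the text ([2]).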
dataset = load_dataset("EleutherAI/wikitext_document_level",'wikitext-103-v1')
merged_data = {}
for partition in ['train', 'test', 'validation']:
    if partition in dataset:
        for i in range(len(dataset[partition])):
            page = dataset[partition][i]['page']
            title = page.split('=', 2)[1].strip()
            text = page.split('=', 2)[2].strip()
            merged_data[title] = text
article_names = list(merged_data.keys())
print(article_names[0])
print(merged_data[article_names[0]])
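# FAISSDocumentStore keeps vectors in a FAISS index ("Flat" = exact search)
# and document text/metadata in the SQLite database given by sql_url;
# embedding_dim=768 matches the DPR encoders used below.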
# SQLAlchemy URL for the SQLite database that backs the FAISS store
db_path = "sqlite:///document_store.db"

use_faiss = True  # set to False to use a simple in-memory store instead
if use_faiss:
    document_store = FAISSDocumentStore(
        faiss_index_factory_str="Flat",
        return_embedding=True,
        embedding_dim=768,
        sql_url=db_path,
        index="title",
        progress_bar=False,
    )
else:
    document_store = InMemoryDocumentStore()
# Clear any existing documents so re-runs start from an empty store
document_store.delete_documents()

# Add each article to the document store
#for k in tqdm(list(merged_data.keys())[0:100]):  # smaller slice for a quick test
for k in tqdm(list(merged_data.keys())):
    document_store.write_documents([{
        "content": merged_data[k],
        "meta": {
            "title": k,
        }
    }])
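# Note: per-document writes are slow at this scale. A single batched write
# (a sketch of the same indexing step, using the same data) would typically
# be much faster:
#
#   docs = [{"content": text, "meta": {"title": title}}
#           for title, text in merged_data.items()]
#   document_store.write_documents(docs)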
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=True,
    #embed_title=True,
)
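# Compute a DPR passage embedding for every indexed document and store it in
# the FAISS index; on the full corpus this is the slowest step and benefits
# greatly from a GPU.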
document_store.update_embeddings(retriever=retriever)
# Initialize the RAG generator
generator = RAGenerator(
    model_name_or_path="facebook/rag-token-nq",
    use_gpu=True,
    max_length=50,
    min_length=20,
    num_beams=5,
    retriever=retriever,
)
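# Extractive QA: the DPR retriever fetches candidate articles and a RoBERTa
# reader fine-tuned on SQuAD 2.0 extracts answer spans from them.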
query = 'Who is Tina Fey?'
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
pipe = ExtractiveQAPipeline(reader, retriever)
prediction = pipe.run(
    query=query, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
)
print_answers(prediction, details="minimum")
# Generative QA: the RAG model conditions its generated answer on the
# retrieved passages
pipe = GenerativeQAPipeline(generator=generator, retriever=retriever)
res = pipe.run(query=query, params={"Generator": {"top_k": 5}, "Retriever": {"top_k": 5}})
print_answers(res, details="minimum")
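# Trace each generated answer back to its source articles via the answer's
# document_ids.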
for a in res['answers']:
    docs = [document_store.get_document_by_id(d) for d in a.document_ids]
    titles = [d.meta['title'] for d in docs]
    print(a.score, titles, a.answer)