Last active
October 15, 2023 13:52
-
-
Save ZanSara/33020a980f2f535e2529df4ca4e8f08a to your computer and use it in GitHub Desktop.
Translated Hybrid Retrieval
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "L40ZxZW8lXQh" | |
}, | |
"outputs": [], | |
"source": [ | |
"%%bash\n", | |
"# Install Haystack and the other dependencies of this notebook.\n", | |
"# The %%bash cell magic has to be the first line of the cell to be recognised.\n", | |
"\n", | |
"# -y answers the confirmation prompt, which cannot be typed into this cell.\n", | |
"apt install -y libgraphviz-dev\n", | |
"pip install --upgrade pip\n", | |
"pip install pygraphviz\n", | |
"# Quoted: an unquoted '>' would be parsed by the shell as an output redirection,\n", | |
"# and unquoted square brackets are subject to glob expansion.\n", | |
"pip install 'datasets>=2.6.1'\n", | |
"pip install 'farm-haystack[inference]'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "cLbh-UtelXRL" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Grab a medical dataset\n", | |
"from datasets import load_dataset\n", | |
"\n", | |
"dataset = load_dataset(\"ywchoi/pubmed_abstract_3\", split=\"test\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "RvrG_QzirSsq" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Create the Documents\n", | |
"# The data has 3 features:\n", | |
"# * pmid\n", | |
"# * title\n", | |
"# * text\n", | |
"\n", | |
"from haystack.schema import Document\n", | |
"\n", | |
"# One Document per dataset row: title and text are joined into the content,\n", | |
"# and the three original fields are kept in the metadata.\n", | |
"documents = [\n", | |
"    Document(\n", | |
"        content=doc[\"title\"] + \" \" + doc[\"text\"],\n", | |
"        meta={\"title\": doc[\"title\"], \"abstract\": doc[\"text\"], \"pmid\": doc[\"pmid\"]},\n", | |
"    )\n", | |
"    for doc in dataset\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "RrCCmLvGqhYw" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Pre-process the documents to make them easier to search\n", | |
"\n", | |
"from haystack.nodes import PreProcessor\n", | |
"\n", | |
"preprocessor = PreProcessor(\n", | |
" clean_empty_lines=True,\n", | |
" clean_whitespace=True,\n", | |
" clean_header_footer=True,\n", | |
" split_by=\"word\",\n", | |
" split_length=512,\n", | |
" split_overlap=32,\n", | |
" split_respect_sentence_boundary=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "8PzBU_jnsBTZ" | |
}, | |
"outputs": [], | |
"source": [ | |
"docs_to_index = preprocessor.process(documents)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Create the document store and write the docs in\n", | |
"from haystack.document_stores import InMemoryDocumentStore\n", | |
"\n", | |
"document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=384)\n", | |
"\n", | |
"document_store.write_documents(docs_to_index)" | |
], | |
"metadata": { | |
"id": "1bXhRFuHjE8V" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "rXHbHru0lXRY" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Create both Retrievers\n", | |
"\n", | |
"from haystack.nodes import EmbeddingRetriever, BM25Retriever\n", | |
"\n", | |
"sparse_retriever = BM25Retriever(document_store=document_store)\n", | |
"dense_retriever = EmbeddingRetriever(\n", | |
" document_store=document_store,\n", | |
" embedding_model=\"sentence-transformers/all-MiniLM-L6-v2\",\n", | |
" use_gpu=True,\n", | |
" scale_score=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "7S-QdaDYlXRg" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Update the embeddings for the dense retriever\n", | |
"\n", | |
"document_store.update_embeddings(retriever=dense_retriever)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "d_RiKspTlXRl" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Create the Ranker and the JoinDocuments nodes\n", | |
"\n", | |
"from haystack.nodes import JoinDocuments, SentenceTransformersRanker\n", | |
"\n", | |
"join_documents = JoinDocuments(join_mode=\"concatenate\")\n", | |
"ranker = SentenceTransformersRanker(model_name_or_path=\"cross-encoder/ms-marco-MiniLM-L-6-v2\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Create the Reader\n", | |
"\n", | |
"from haystack.nodes import FARMReader\n", | |
"\n", | |
"model_name_or_path = \"deepset/roberta-base-squad2\"\n", | |
"reader = FARMReader(model_name_or_path, use_gpu=True)" | |
], | |
"metadata": { | |
"id": "0gkTonmVgfaD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Create the LanguageClassifier node (STUB NODE FOR NOW)\n", | |
"\n", | |
"from haystack.nodes.other import JoinNode\n", | |
"from haystack.nodes import BaseComponent, TransformersTranslator\n", | |
"\n", | |
"class LanguageClassifier(BaseComponent):\n", | |
"    \"\"\"Stub router that sends a query to one of two outgoing edges.\n", | |
"\n", | |
"    Queries containing the substring \"pour\" leave through `output_2`;\n", | |
"    every other query leaves through `output_1`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    outgoing_edges = 2\n", | |
"\n", | |
"    def run(self, query):\n", | |
"        \"\"\"Return the unchanged query together with the edge it should follow.\"\"\"\n", | |
"        edge = \"output_2\" if \"pour\" in query else \"output_1\"\n", | |
"        return {\"query\": query}, edge\n", | |
"\n", | |
"    def run_batch(self, queries):\n", | |
"        \"\"\"Batch mode is not implemented for this stub.\"\"\"\n", | |
"        # Fail loudly: returning None would not match the (output, edge) pair run() returns.\n", | |
"        raise NotImplementedError(\"LanguageClassifier does not support batch mode.\")\n", | |
"\n", | |
"language_classifier = LanguageClassifier()\n" | |
], | |
"metadata": { | |
"id": "rBH6Rk4As5Qo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# WORKAROUND NODES\n", | |
"\n", | |
"class TranslatorWorkaround(TransformersTranslator):\n", | |
"    \"\"\"Translator whose output also carries an empty `documents` list.\"\"\"\n", | |
"\n", | |
"    outgoing_edges = 1\n", | |
"\n", | |
"    def run(self, query):\n", | |
"        \"\"\"Run the parent translator on `query` and add `documents: []` to its output.\"\"\"\n", | |
"        # The edge name returned by the parent is ignored; this node always uses output_1.\n", | |
"        results, _ = super().run(query=query)\n", | |
"        return {**results, \"documents\": []}, \"output_1\"\n", | |
"\n", | |
"    def run_batch(self, queries):\n", | |
"        \"\"\"Batch mode is not implemented for this workaround.\"\"\"\n", | |
"        raise NotImplementedError(\"TranslatorWorkaround does not support batch mode.\")\n", | |
"\n", | |
"\n", | |
"translator_workaround = TranslatorWorkaround(model_name_or_path=\"Helsinki-NLP/opus-mt-fr-en\")\n", | |
"# translator = TransformersTranslator(model_name_or_path=\"Helsinki-NLP/opus-mt-fr-en\")\n", | |
"\n", | |
"class JoinQueryWorkaround(JoinNode):\n", | |
"    \"\"\"Join node that pairs the first input's query with the second input's documents.\"\"\"\n", | |
"\n", | |
"    def run_accumulated(self, inputs, *args, **kwargs):\n", | |
"        \"\"\"Merge the accumulated inputs into a single query/documents dict.\n", | |
"\n", | |
"        With two or more inputs the query comes from the first input and the\n", | |
"        documents from the second. With a single input both are read from it,\n", | |
"        instead of failing with an IndexError on `inputs[1]`.\n", | |
"        \"\"\"\n", | |
"        # NOTE(review): a single input presumably occurs when the translator branch\n", | |
"        # is skipped - confirm against the pipeline's join semantics.\n", | |
"        documents_source = inputs[1] if len(inputs) > 1 else inputs[0]\n", | |
"        merged = {\n", | |
"            \"query\": inputs[0].get(\"query\", None),\n", | |
"            \"documents\": documents_source.get(\"documents\", None),\n", | |
"        }\n", | |
"        return merged, \"output_1\"\n", | |
"\n", | |
"    def run_batch_accumulated(self, inputs):\n", | |
"        \"\"\"Batch mode is not implemented for this workaround.\"\"\"\n", | |
"        raise NotImplementedError(\"JoinQueryWorkaround does not support batch mode.\")\n", | |
"\n", | |
"join_query_workaround = JoinQueryWorkaround()\n" | |
], | |
"metadata": { | |
"id": "4TMX3LeVhByK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "i0XLbnAXlXRt" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Assemble the Pipeline\n", | |
"\n", | |
"from haystack.pipelines import Pipeline\n", | |
"\n", | |
"# Wiring table: (component, node name, names of the nodes feeding it).\n", | |
"# Nodes are registered in this order, so each entry only refers to earlier ones.\n", | |
"pipeline_nodes = [\n", | |
"    (language_classifier, \"LanguageClassifier\", [\"Query\"]),\n", | |
"    (translator_workaround, \"TranslatorWorkaround\", [\"LanguageClassifier.output_2\"]),\n", | |
"    (sparse_retriever, \"SparseRetriever\", [\"LanguageClassifier.output_1\", \"TranslatorWorkaround\"]),\n", | |
"    (dense_retriever, \"DenseRetriever\", [\"LanguageClassifier.output_1\", \"TranslatorWorkaround\"]),\n", | |
"    (join_documents, \"JoinDocuments\", [\"SparseRetriever\", \"DenseRetriever\"]),\n", | |
"    (join_query_workaround, \"JoinQueryWorkaround\", [\"TranslatorWorkaround\", \"JoinDocuments\"]),\n", | |
"    (ranker, \"Ranker\", [\"JoinQueryWorkaround\"]),\n", | |
"    (reader, \"Reader\", [\"Ranker\"]),\n", | |
"]\n", | |
"\n", | |
"pipeline = Pipeline()\n", | |
"for node_component, node_name, node_inputs in pipeline_nodes:\n", | |
"    pipeline.add_node(component=node_component, name=node_name, inputs=node_inputs)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "oCIMtwmThQG4" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Draw the pipeline\n", | |
"pipeline.draw(\"pipeline_hybrid.png\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "p-5WbeBulXR0" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Run with a French query\n", | |
"\n", | |
"prediction = pipeline.run(\n", | |
" query=\"traitement pour le VIH\",\n", | |
" # debug=True\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "mSUiizGNytwX" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Inspect the results\n", | |
"\n", | |
"import pprint\n", | |
"\n", | |
"pprint.pprint(prediction)" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"gpuType": "T4", | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.4" | |
}, | |
"orig_nbformat": 4, | |
"vscode": { | |
"interpreter": { | |
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment