Translated Hybrid Retrieval
@ZanSara, last active October 15, 2023 13:52
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L40ZxZW8lXQh"
},
"outputs": [],
"source": [
"# Install Haystack\n",
"\n",
"%%bash\n",
"\n",
"apt install libgraphviz-dev\n",
"pip install --upgrade pip\n",
"pip install pygraphviz\n",
"pip install datasets>=2.6.1\n",
"pip install farm-haystack[inference]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cLbh-UtelXRL"
},
"outputs": [],
"source": [
"# Grab a medical dataset\n",
"from datasets import load_dataset\n",
"\n",
"dataset = load_dataset(\"ywchoi/pubmed_abstract_3\", split=\"test\")"
]
},
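{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (not part of the original notebook): peek at the dataset to\n",
"# confirm it exposes the pmid, title and text fields used below.\n",
"\n",
"print(dataset)\n",
"print(dataset[0])"
]
},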
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RvrG_QzirSsq"
},
"outputs": [],
"source": [
"# Create the Documents\n",
"# The data has 3 features:\n",
"# * pmid\n",
"# * title\n",
"# * text\n",
"\n",
"from haystack.schema import Document\n",
"\n",
"documents = []\n",
"for doc in dataset:\n",
" documents.append(\n",
" Document(\n",
" content=doc[\"title\"] + \" \" + doc[\"text\"],\n",
" meta={\"title\": doc[\"title\"], \"abstract\": doc[\"text\"], \"pmid\": doc[\"pmid\"]},\n",
" )\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RrCCmLvGqhYw"
},
"outputs": [],
"source": [
"# Pre-process the documents to make them easier to search\n",
"\n",
"from haystack.nodes import PreProcessor\n",
"\n",
"preprocessor = PreProcessor(\n",
" clean_empty_lines=True,\n",
" clean_whitespace=True,\n",
" clean_header_footer=True,\n",
" split_by=\"word\",\n",
" split_length=512,\n",
" split_overlap=32,\n",
" split_respect_sentence_boundary=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8PzBU_jnsBTZ"
},
"outputs": [],
"source": [
"docs_to_index = preprocessor.process(documents)"
]
},
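{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check (not part of the original notebook): see how many passages the\n",
"# PreProcessor produced from the raw documents.\n",
"\n",
"print(f\"{len(documents)} documents split into {len(docs_to_index)} passages\")"
]
},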
{
"cell_type": "code",
"source": [
"# Create the document store and write the docs in\n",
"from haystack.document_stores import InMemoryDocumentStore\n",
"\n",
"document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=384)\n",
"\n",
"document_store.write_documents(docs_to_index)"
],
"metadata": {
"id": "1bXhRFuHjE8V"
},
"execution_count": null,
"outputs": []
},
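{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check (not part of the original notebook): confirm the passages were\n",
"# written to the store.\n",
"\n",
"print(document_store.get_document_count())"
]
},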
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rXHbHru0lXRY"
},
"outputs": [],
"source": [
"# Create both Retrievers\n",
"\n",
"from haystack.nodes import EmbeddingRetriever, BM25Retriever\n",
"\n",
"sparse_retriever = BM25Retriever(document_store=document_store)\n",
"dense_retriever = EmbeddingRetriever(\n",
" document_store=document_store,\n",
" embedding_model=\"sentence-transformers/all-MiniLM-L6-v2\",\n",
" use_gpu=True,\n",
" scale_score=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7S-QdaDYlXRg"
},
"outputs": [],
"source": [
"# Update the embeddings for the dense retriever\n",
"\n",
"document_store.update_embeddings(retriever=dense_retriever)"
]
},
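{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick standalone test of both retrievers (not part of the original notebook).\n",
"# retrieve() returns a list of Documents, so this shows what each branch of the\n",
"# hybrid setup contributes before the results are joined and re-ranked.\n",
"\n",
"for retriever in (sparse_retriever, dense_retriever):\n",
"    print(retriever.__class__.__name__)\n",
"    for doc in retriever.retrieve(query=\"treatment for HIV\", top_k=3):\n",
"        print(\"  \", doc.score, \"-\", doc.meta[\"title\"])"
]
},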
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d_RiKspTlXRl"
},
"outputs": [],
"source": [
"# Create the Ranker and the JoinDocuments nodes\n",
"\n",
"from haystack.nodes import JoinDocuments, SentenceTransformersRanker\n",
"\n",
"join_documents = JoinDocuments(join_mode=\"concatenate\")\n",
"ranker = SentenceTransformersRanker(model_name_or_path=\"cross-encoder/ms-marco-MiniLM-L-6-v2\")"
]
},
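{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note (not part of the original notebook): JoinDocuments also supports other join\n",
"# modes such as \"merge\" and \"reciprocal_rank_fusion\". Plain concatenation is enough\n",
"# here because the cross-encoder Ranker re-scores every document downstream anyway.\n",
"\n",
"# join_documents = JoinDocuments(join_mode=\"reciprocal_rank_fusion\")"
]
},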
{
"cell_type": "code",
"source": [
"# Create the Reader\n",
"\n",
"from haystack.nodes import FARMReader\n",
"\n",
"model_name_or_path = \"deepset/roberta-base-squad2\"\n",
"reader = FARMReader(model_name_or_path, use_gpu=True)"
],
"metadata": {
"id": "0gkTonmVgfaD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Create the LanguageClassifier node (STUB NODE FOR NOW)\n",
"\n",
"from haystack.nodes.other import JoinNode\n",
"from haystack.nodes import BaseComponent, TransformersTranslator\n",
"\n",
"class LanguageClassifier(BaseComponent):\n",
"\n",
" outgoing_edges = 2\n",
"\n",
" def run(self, query):\n",
" if \"pour\" in query:\n",
" return {\"query\": query}, \"output_2\"\n",
" return {\"query\": query}, \"output_1\"\n",
"\n",
" def run_batch(self, queries):\n",
" pass\n",
"\n",
"language_classifier = LanguageClassifier()\n"
],
"metadata": {
"id": "rBH6Rk4As5Qo"
},
"execution_count": null,
"outputs": []
},
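{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A more robust classifier (sketch, not part of the original notebook): instead of\n",
"# checking for the word \"pour\", detect the query language with the langdetect\n",
"# package (pip install langdetect) and route anything non-English to the translator.\n",
"# Left commented out since langdetect is not installed in the first cell.\n",
"\n",
"# from langdetect import detect\n",
"#\n",
"# class LangDetectClassifier(BaseComponent):\n",
"#\n",
"#     outgoing_edges = 2\n",
"#\n",
"#     def run(self, query):\n",
"#         edge = \"output_1\" if detect(query) == \"en\" else \"output_2\"\n",
"#         return {\"query\": query}, edge\n",
"#\n",
"#     def run_batch(self, queries):\n",
"#         pass"
]
},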
{
"cell_type": "code",
"source": [
"# WORKAROUND NODES\n",
"\n",
"class TranslatorWorkaround(TransformersTranslator):\n",
"\n",
" outgoing_edges = 1\n",
"\n",
" def run(self, query):\n",
" results, edge = super().run(query=query)\n",
" return {**results, \"documents\": [] }, \"output_1\"\n",
"\n",
" def run_batch(self, queries):\n",
" pass\n",
"\n",
"\n",
"translator_workaround = TranslatorWorkaround(model_name_or_path=\"Helsinki-NLP/opus-mt-fr-en\")\n",
"# translator = TransformersTranslator(model_name_or_path=\"Helsinki-NLP/opus-mt-fr-en\")\n",
"\n",
"class JoinQueryWorkaround(JoinNode):\n",
"\n",
" def run_accumulated(self, inputs, *args, **kwargs):\n",
" return {\"query\": inputs[0].get(\"query\", None), \"documents\": inputs[1].get(\"documents\", None)}, \"output_1\"\n",
"\n",
" def run_batch_accumulated(self, inputs):\n",
" pass\n",
"\n",
"join_query_workaround = JoinQueryWorkaround()\n"
],
"metadata": {
"id": "4TMX3LeVhByK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i0XLbnAXlXRt"
},
"outputs": [],
"source": [
"# Assemble the Pipeline\n",
"\n",
"from haystack.pipelines import Pipeline\n",
"\n",
"pipeline = Pipeline()\n",
"pipeline.add_node(component=language_classifier, name=\"LanguageClassifier\", inputs=[\"Query\"])\n",
"pipeline.add_node(component=translator_workaround, name=\"TranslatorWorkaround\", inputs=[\"LanguageClassifier.output_2\"])\n",
"pipeline.add_node(component=sparse_retriever, name=\"SparseRetriever\", inputs=[\"LanguageClassifier.output_1\", \"TranslatorWorkaround\"])\n",
"pipeline.add_node(component=dense_retriever, name=\"DenseRetriever\", inputs=[\"LanguageClassifier.output_1\", \"TranslatorWorkaround\"])\n",
"pipeline.add_node(component=join_documents, name=\"JoinDocuments\", inputs=[\"SparseRetriever\", \"DenseRetriever\"])\n",
"pipeline.add_node(component=join_query_workaround, name=\"JoinQueryWorkaround\", inputs=[\"TranslatorWorkaround\", \"JoinDocuments\"])\n",
"pipeline.add_node(component=ranker, name=\"Ranker\", inputs=[\"JoinQueryWorkaround\"])\n",
"pipeline.add_node(component=reader, name=\"Reader\", inputs=[\"Ranker\"])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oCIMtwmThQG4"
},
"outputs": [],
"source": [
"# Draw the pipeline\n",
"pipeline.draw(\"pipeline_hybrid.png\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p-5WbeBulXR0"
},
"outputs": [],
"source": [
"# Run with a French query\n",
"\n",
"prediction = pipeline.run(\n",
" query=\"traitement pour le VIH\",\n",
" # debug=True\n",
")"
]
},
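{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A more compact look at the results (not part of the original notebook):\n",
"# print_answers trims the prediction down to the answer strings and their scores.\n",
"\n",
"from haystack.utils import print_answers\n",
"\n",
"print_answers(prediction, details=\"minimum\")"
]
},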
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mSUiizGNytwX"
},
"outputs": [],
"source": [
"# Inspect the results\n",
"\n",
"import pprint\n",
"\n",
"pprint.pprint(prediction)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}