{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"mount_file_id": "1H8AmzcB4bexq7AtUbcxS16olIHGY3-XS",
"authorship_tag": "ABX9TyMp8KUa9MPANxhlncPgDKUh",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/leolivier/f9b2996a5404b841009840fbd6d2345c/demorag.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"**Before starting for the 1rst time**:\n",
"1. link your google drive space\n",
"1. execute the mkdir command below\n",
"1. import your PDFs files inside the `/content/drive/MyDrive/Colab\\ Notebooks/demodataPDFs` folder"
],
"metadata": {
"id": "TBL92PUHbIGj"
}
},
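{
"cell_type": "markdown",
"source": [
"A minimal sketch for step 1, using the standard `google.colab.drive` helper; `/content/drive` is the mount point assumed by all paths below:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Mount Google Drive so this notebook can read/write under /content/drive/MyDrive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},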
{
"cell_type": "code",
"source": [
"!mkdir -p '/content/drive/MyDrive/Colab Notebooks/demodataPDFs' '/content/drive/MyDrive/Colab Notebooks/vectorstore/db_faiss'"
],
"metadata": {
"id": "D-sqFuXwS0Ph"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"1. Install needed python libraries"
],
"metadata": {
"id": "HzSao3JfSniI"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JVBFwqToRA6n"
},
"outputs": [],
"source": [
"!pip install langchain chainlit pypdf sentence_transformers ctransformers faiss-gpu"
]
},
{
"cell_type": "markdown",
"source": [
"2. function for loading documents in the vector *DB*"
],
"metadata": {
"id": "RMbZPm-wcLn4"
}
},
{
"cell_type": "code",
"source": [
"from langchain.embeddings import HuggingFaceEmbeddings\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.document_loaders import PyPDFLoader, DirectoryLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"DEVICE = 'cpu'\n",
"#DEVICE = 'cuda'\n",
"TRANSFORMER_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'\n",
"\n",
"DATA_PATH = '/content/drive/MyDrive/Colab Notebooks/demodataPDFs/'\n",
"DB_FAISS_PATH = '/content/drive/MyDrive/Colab Notebooks/vectorstore/db_faiss'\n",
"\n",
"# Create vector database\n",
"def create_vector_db():\n",
" loader = DirectoryLoader(DATA_PATH,\n",
" glob='*.pdf',\n",
" loader_cls=PyPDFLoader)\n",
" documents = loader.load()\n",
" text_splitter = RecursiveCharacterTextSplitter(\n",
" chunk_size=500,\n",
" chunk_overlap=50)\n",
" texts = text_splitter.split_documents(documents)\n",
" embeddings = HuggingFaceEmbeddings(\n",
" model_name=TRANSFORMER_MODEL,\n",
" model_kwargs={'device': DEVICE})\n",
" db = FAISS.from_documents(texts, embeddings)\n",
" db.save_local(DB_FAISS_PATH)\n",
"\n"
],
"metadata": {
"id": "KKivuYcvRqIR"
},
"execution_count": 38,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"**Check: did you import your PDF files as described above?**\n",
"\n",
"Ok, so\n",
"\n",
"3. load documents in vector database (can be done only once as long as you don't add new documents)"
],
"metadata": {
"id": "55jmkn1h7g5N"
}
},
{
"cell_type": "code",
"source": [
"create_vector_db()"
],
"metadata": {
"id": "dEzY0mzd7leZ"
},
"execution_count": 21,
"outputs": []
},
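{
"cell_type": "markdown",
"source": [
"Optional sanity check, not part of the original flow: reload the saved index and run a raw similarity search against it. This sketch assumes the cells above were executed (so `TRANSFORMER_MODEL`, `DEVICE`, `DB_FAISS_PATH`, `HuggingFaceEmbeddings` and `FAISS` are in scope); the query string is only a placeholder."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Reload the FAISS index from disk and fetch the 2 chunks closest to a test query\n",
"embeddings = HuggingFaceEmbeddings(\n",
"    model_name=TRANSFORMER_MODEL,\n",
"    model_kwargs={'device': DEVICE})\n",
"db = FAISS.load_local(DB_FAISS_PATH, embeddings)\n",
"docs = db.similarity_search('your test question here', k=2)  # placeholder query\n",
"for doc in docs:\n",
"    print(doc.metadata.get('source'), '->', doc.page_content[:200])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},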
{
"cell_type": "markdown",
"source": [
"4. create the demoRAG.py file containing all the python code"
],
"metadata": {
"id": "yUd-u8kVeMz-"
}
},
{
"cell_type": "code",
"source": [
"%%writefile demoRAG.py\n",
"from langchain.embeddings import HuggingFaceEmbeddings\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.document_loaders import PyPDFLoader, DirectoryLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.embeddings import HuggingFaceEmbeddings\n",
"from langchain import PromptTemplate\n",
"from langchain.llms import CTransformers\n",
"from langchain.chains import RetrievalQA\n",
"import chainlit as cl\n",
"\n",
"#DEVICE = 'cpu'\n",
"DEVICE = 'cuda'\n",
"HUGGINGFACEHUB_API_TOKEN = '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'\n",
"#MODEL = 'meta/Llama-2-7B-Chat-GGML'\n",
"TRANSFORMER_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'\n",
"CHAT_MODEL = 'TheBloke/Llama-2-7B-Chat-GGML'\n",
"\n",
"DATA_PATH = '/content/drive/MyDrive/Colab Notebooks/demodataPDFs/'\n",
"DB_FAISS_PATH = '/content/drive/MyDrive/Colab Notebooks/vectorstore/db_faiss'\n",
"\n",
"# Create vector database\n",
"def create_vector_db():\n",
" loader = DirectoryLoader(DATA_PATH,\n",
" glob='*.pdf',\n",
" loader_cls=PyPDFLoader)\n",
" documents = loader.load()\n",
" text_splitter = RecursiveCharacterTextSplitter(\n",
" chunk_size=500,\n",
" chunk_overlap=50)\n",
" texts = text_splitter.split_documents(documents)\n",
" embeddings = HuggingFaceEmbeddings(\n",
" model_name=TRANSFORMER_MODEL,\n",
" model_kwargs={'device': DEVICE})\n",
" db = FAISS.from_documents(texts, embeddings)\n",
" db.save_local(DB_FAISS_PATH)\n",
"\n",
"def open_vector_db():\n",
" embeddings = HuggingFaceEmbeddings(\n",
" model_name=TRANSFORMER_MODEL,\n",
" model_kwargs={'device': DEVICE})\n",
" db = FAISS.load_local(DB_FAISS_PATH, embeddings)\n",
" return db\n",
"\n",
"def set_custom_prompt():\n",
" \"\"\"\n",
" Prompt template for QA retrieval for each vectorstore\n",
" \"\"\"\n",
" custom_prompt_template = '''Use the following pieces of information to answer the user’s question.\n",
"If you don’t know the answer, just say that you don’t know, don’t try to make up an answer.\n",
"\n",
"Context: {context}\n",
"Question: {question}\n",
"\n",
"Only return the helpful answer below and nothing else.\n",
"Helpful and Caring answer:\n",
"'''\n",
" prompt = PromptTemplate(template=custom_prompt_template,\n",
" input_variables=['context', 'question'])\n",
" return prompt\n",
"#Retrieval QA Chain\n",
"def retrieval_qa_chain(llm, prompt, db):\n",
" qa_chain = RetrievalQA.from_chain_type(llm=llm,\n",
" chain_type='stuff',\n",
" retriever=db.as_retriever(search_kwargs={'k': 2}),\n",
" return_source_documents=True,\n",
" chain_type_kwargs={'prompt': prompt}\n",
" )\n",
" return qa_chain\n",
"#Loading the model\n",
"\n",
"def load_llm():\n",
" # Load the locally downloaded model here\n",
" gpu_layers = 100 if DEVICE == 'cuda' else 0\n",
" llm = CTransformers(\n",
" model = CHAT_MODEL,\n",
" model_type=\"llama\",\n",
" gpu_layers = gpu_layers,\n",
" max_new_tokens = 512,\n",
" temperature = 0.5\n",
" )\n",
" return llm\n",
"#QA Model Function\n",
"def qa_bot():\n",
" db = open_vector_db()\n",
" llm = load_llm()\n",
" qa_prompt = set_custom_prompt()\n",
" qa = retrieval_qa_chain(llm, qa_prompt, db)\n",
" return qa\n",
"#output function\n",
"def final_result(query):\n",
" qa_result = qa_bot()\n",
" response = qa_result({'query': query})\n",
" return response\n",
"\n",
"#chainlit code\n",
"@cl.on_chat_start\n",
"async def start():\n",
" chain = qa_bot()\n",
" msg = cl.Message(content=\"Starting your gen AI bot!...\")\n",
" await msg.send()\n",
" msg.content = \"Welcome to Demo Bot!. Ask your question here:\"\n",
" await msg.update()\n",
" cl.user_session.set(\"chain\", chain)\n",
"@cl.on_message\n",
"async def main(message):\n",
" chain = cl.user_session.get(\"chain\")\n",
" cb = cl.AsyncLangchainCallbackHandler(\n",
" stream_final_answer=True, answer_prefix_tokens=[\"FINAL\", \"ANSWER\"]\n",
" )\n",
" cb.answer_reached = True\n",
" res = await chain.acall(message, callbacks=[cb])\n",
" answer = res[\"result\"]\n",
" sources = res[\"documents\"]\n",
" if sources:\n",
" answer += f\"\\nSources:\" + str(sources)\n",
" else:\n",
" answer += \"\\nNo sources found\"\n",
" await cl.Message(content=answer).send()\n"
],
"metadata": {
"id": "BF9-N-9lJ_PL"
},
"execution_count": null,
"outputs": []
},
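{
"cell_type": "markdown",
"source": [
"Optional smoke test before exposing the app: call the chain once directly from the notebook. A sketch under the assumption that `demoRAG.py` imports cleanly in the kernel (the chainlit decorators only register callbacks at import time); the question is a placeholder."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Run one query through the full RAG pipeline without starting chainlit\n",
"from demoRAG import final_result\n",
"response = final_result('What is this document about?')  # placeholder question\n",
"print(response['result'])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},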
{
"cell_type": "code",
"source": [
"# Install localtunnel to be able to use chainlit (don't look at the warnings)\n",
"!npm install localtunnel"
],
"metadata": {
"id": "nmysGT6pNaVv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# use the IP printed by this code in the localtunnel window created in the next step\n",
"print(\"############\")\n",
"print(\"IP v\")\n",
"!curl https://ipv4.icanhazip.com/\n",
"print(\"############\")"
],
"metadata": {
"id": "ySLxnMHSTMJJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# booting server: click on the localtunnel URL then paste the above IP address\n",
"# in the field to be directed to the chainlit chatbot screen\n",
"!chainlit run demoRAG.py -w & npx localtunnel --port 8000"
],
"metadata": {
"id": "5S_ZUSwNK6OW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# stop chainlit\n",
"pid = !ps aux | grep \"chainlit\" | awk '{print $2}'\n",
"pid=pid[0]\n",
"print(pid)\n",
"!id\n",
"!kill -9 $pid\n",
"!ps aux | grep \"chainlit\""
],
"metadata": {
"id": "a-5NwHb0OMOD"
},
"execution_count": null,
"outputs": []
}
]
}