
@arita37
Forked from mkeywood1/BRAG Example.ipynb
Created May 11, 2024 18:02
Examples of using BRAG for Balanced Retrieval Augmented Generation
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e5d368fe",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import openai\n",
"import pandas as pd\n",
"from llama_index.llms import AzureOpenAI\n",
"from llama_index.embeddings import OpenAIEmbedding\n",
"from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n",
"from llama_index import download_loader\n",
"from llama_index import StorageContext, load_index_from_storage\n",
"from llama_index import set_global_service_context"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9db46493-1f9b-4440-be13-d7a151cf90d3",
"metadata": {},
"outputs": [],
"source": [
"openai.api_type = \"azure\"\n",
"openai.api_base = os.environ['OPENAI_API_BASE']\n",
"openai.api_version = os.environ['OPENAI_API_VERSION']\n",
"openai.api_key = os.environ['OPENAI_API_KEY']\n",
"\n",
"llm = AzureOpenAI(\n",
" model=\"gpt-35-turbo\",\n",
" engine=\"gpt-35-turbo\",\n",
" api_type=\"azure\",\n",
" api_key=os.environ['OPENAI_API_KEY'],\n",
" api_base=os.environ['OPENAI_API_BASE'],\n",
" api_version=os.environ['OPENAI_API_VERSION']\n",
")\n",
"\n",
"# You need to deploy your own embedding model as well as your own chat completion model\n",
"embed_model = OpenAIEmbedding(\n",
" model=\"text-embedding-ada-002\",\n",
" deployment_name=\"text-embedding-ada-002\",\n",
" api_type=\"azure\",\n",
" api_key=os.environ['OPENAI_API_KEY'],\n",
" api_base=os.environ['OPENAI_API_BASE'],\n",
" api_version=os.environ['OPENAI_API_VERSION']\n",
")\n",
"\n",
"service_context = ServiceContext.from_defaults(\n",
" llm=llm,\n",
" embed_model=embed_model,\n",
")\n",
"set_global_service_context(service_context)"
]
},
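{
"cell_type": "markdown",
"id": "a1f2b3c4",
"metadata": {},
"source": [
"Optional sanity check before building the index: confirm both Azure deployments respond. A minimal sketch, assuming the legacy llama_index interfaces `llm.complete()` and `embed_model.get_text_embedding()` from this era of the library."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (sketch): verify both Azure deployments respond before indexing.\n",
"# Assumes the legacy llama_index API, where AzureOpenAI exposes complete() and\n",
"# OpenAIEmbedding exposes get_text_embedding().\n",
"print(llm.complete(\"Reply with the single word OK.\"))\n",
"print(len(embed_model.get_text_embedding(\"hello\")))  # 1536 dimensions for ada-002"
]
},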
{
"cell_type": "code",
"execution_count": 3,
"id": "c19a99f1",
"metadata": {},
"outputs": [],
"source": [
"PDFReader = download_loader(\"PDFReader\")\n",
"loader = PDFReader()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4c204940-92fa-4c02-a399-917c7969e5ab",
"metadata": {},
"outputs": [],
"source": [
"# Load the document\n",
"documents = loader.load_data(file=\"2401.04088.pdf\")\n",
"\n",
"# Create the vectorstore\n",
"index = VectorStoreIndex.from_documents(documents)\n",
"\n",
"# Create the Query Engine\n",
"query_engine = index.as_query_engine()"
]
},
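{
"cell_type": "markdown",
"id": "c3d4e5f6",
"metadata": {},
"source": [
"The imports above include `StorageContext` and `load_index_from_storage` but never use them. A minimal sketch of how they fit in, assuming the legacy llama_index persistence API and a local `./storage` directory, so the PDF only has to be embedded once:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: persist the index to disk so re-runs skip the embedding step.\n",
"index.storage_context.persist(persist_dir=\"./storage\")\n",
"\n",
"# On a later run, reload it instead of calling VectorStoreIndex.from_documents:\n",
"storage_context = StorageContext.from_defaults(persist_dir=\"./storage\")\n",
"index = load_index_from_storage(storage_context)\n",
"query_engine = index.as_query_engine()"
]
},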
{
"cell_type": "code",
"execution_count": 5,
"id": "9135eeba-7084-458d-9883-c043dc9fbebc",
"metadata": {},
"outputs": [],
"source": [
"def ask(question, context=\"\", system=\"\"):\n",
"\n",
" # Set default System Message\n",
" if system == \"\":\n",
" system = \"\"\"You are an expert in Artificial Intelligence Research Papers. \n",
"Use the following pieces of context to answer the users question. \n",
"If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
"\"\"\"\n",
"\n",
" # Prepend context if used\n",
" if context != \"\":\n",
" question = \"Use the following context to answer the users question:\\n```\\n\" + context + \"\\n```\\n\\n\" + question\n",
"\n",
" response = openai.ChatCompletion.create(\n",
" engine=\"gpt-35-turbo\",\n",
" messages = [{\"role\":\"system\",\"content\":system},{\"role\":\"user\",\"content\":question}],\n",
" temperature=0.0,\n",
" max_tokens=500,\n",
" top_p=0.95,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" stop=None)\n",
" \n",
" return response['choices'][0]['message']['content']"
]
},
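{
"cell_type": "markdown",
"id": "e5f6a7b8",
"metadata": {},
"source": [
"`ask` works with or without retrieved context: with no context it is a plain chat completion, and with a context string the text is fenced into the prompt. A quick illustrative call (the question is my own example; output will vary with the deployment):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a7b8c9",
"metadata": {},
"outputs": [],
"source": [
"# Example call with no context: the answer comes from the model alone.\n",
"ask(\"In one sentence, what is a Sparse Mixture of Experts language model?\")"
]
},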
{
"cell_type": "code",
"execution_count": 6,
"id": "d93fdd98-1588-48a1-bed3-dc9fe902e39a",
"metadata": {},
"outputs": [],
"source": [
"def extract_section(documents, section_name, debug=False):\n",
" section_page = \"\"\n",
" section_text = \"\"\n",
"\n",
" for idx, page in enumerate(documents):\n",
" if section_text == \"\" and section_name in page.text.lower():\n",
" if debug: print(idx)\n",
"\n",
" context = page.text\n",
" if idx < len(documents)-2:\n",
" context += \"\\n\" + documents[idx+1].text\n",
" context += \"\\n\" + documents[idx+2].text\n",
"\n",
" answer = ask(f\"Does the above have the section called '{section_name}' or similar, and does it, in detail, explain the {section_name}?\", context)\n",
" if answer.startswith(\"Yes\"):\n",
" answer = ask(f\"\\n-----\\nWhat is the {section_name} in the document? Return everything in this section, up to the next heading. Do not interpret it, give me the verbatim text.\", context)\n",
" if debug: print(answer + \"\\n----------\")\n",
" section_page = idx + 1\n",
" section_text = answer\n",
" if debug: print(section_page, section_text, validate)\n",
"\n",
" return section_text, section_page"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "af5a0ab4-6157-4354-8604-93c7727ba71b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('The authors mentioned before the abstract are:\\n1. Albert Q. Jiang\\n2. Alexandre Sablayrolles\\n3. Antoine Roux\\n4. Arthur Mensch\\n5. Blanche Savary\\n6. Chris Bamford\\n7. Devendra Singh Chaplot\\n8. Diego de las Casas\\n9. Emma Bou Hanna\\n10. Florian Bressand\\n11. Gianna Lengyel\\n12. Guillaume Bour\\n13. Guillaume Lample\\n14. Lélio Renard Lavaud\\n15. Lucile Saulnier\\n16. Marie-Anne Lachaux\\n17. Pierre Stock\\n18. Sandeep Subramanian\\n19. Sophia Yang\\n20. Szymon Antoniak\\n21. Teven Le Scao\\n22. Théophile Gervet\\n23. Thibaut Lavril\\n24. Thomas Wang\\n25. Timothée Lacroix\\n26. William El Sayed',\n",
" 1)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sections = {}\n",
"\n",
"sections[\"authors\"] = (ask(\"Who are the authors mentioned before the abstract\", documents[0].text), 1)\n",
"sections[\"authors\"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9e310758-5703-45b2-9f8f-9aa2cb0fdbd3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('We introduce Mixtral 8x7B, a Sparse Mixture of Experts (SMoE) language model. Mixtral has the same architecture as Mistral 7B, with the difference that each layer is composed of 8 feedforward blocks (i.e. experts). For every token, at each layer, a router network selects two experts to process the current state and combine their outputs. Even though each token only sees two experts, the selected experts can be different at each timestep. As a result, each token has access to 47B parameters, but only uses 13B active parameters during inference. Mixtral was trained with a context size of 32k tokens and it outperforms or matches Llama 2 70B and GPT-3.5 across all evaluated benchmarks. In particular, Mixtral vastly outperforms Llama 2 70B on mathematics, code generation, and multilingual benchmarks. We also provide a model fine-tuned to follow instructions, Mixtral 8x7B – Instruct, that surpasses GPT-3.5 Turbo, Claude-2.1, Gemini Pro, and Llama 2 70B – chat model on human benchmarks. Both the base and instruct models are released under the Apache 2.0 license.',\n",
" 1)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sections[\"abstract\"] = extract_section(documents, \"abstract\")\n",
"sections[\"abstract\"]"
]
},
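{
"cell_type": "markdown",
"id": "a7b8c9d0",
"metadata": {},
"source": [
"`extract_section` is the core of the BRAG idea: walk the pages, ask the model whether a named section is present, and store the verbatim text plus its page number, so later questions run against the exact section rather than whatever chunks vector retrieval happens to surface. The same helper can pull any named section; the name below is illustrative and must appear (lowercased) in the page text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8c9d0e1",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: extract another named section the same way. If the section\n",
"# is not found, extract_section returns (\"\", \"\").\n",
"sections[\"conclusion\"] = extract_section(documents, \"conclusion\")"
]
},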
{
"cell_type": "code",
"execution_count": 16,
"id": "ea941a49-d534-42f0-8f40-3448f57d2483",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"What licenses are mentioned?\n",
"The licenses mentioned in the context information are not provided.\n",
"CPU times: total: 15.6 ms\n",
"Wall time: 1.79 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"query = 'What licenses are mentioned?'\n",
"print(query)\n",
"answer = query_engine.query(query)\n",
"print(answer.response)\n",
"find = re.findall(r\"'page_label': '[^']*'\", str(answer.metadata))\n",
"page_num = eval(\"{\" + find[0] + \"}\")[\"page_label\"]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "5b3a4a83-43a1-4f40-a268-62345c0e1b4a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: total: 15.6 ms\n",
"Wall time: 1.15 s\n"
]
},
{
"data": {
"text/plain": [
"'The licenses mentioned in the context are the Apache 2.0 license.'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"ask(query, sections[\"abstract\"][0])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "64242645-bd0b-4561-8c71-45a1a99eb390",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Is Jacob Austin an author of this paper?\n",
"Yes.\n",
"CPU times: total: 31.2 ms\n",
"Wall time: 1.42 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"query = 'Is Jacob Austin an author of this paper?'\n",
"print(query)\n",
"answer = query_engine.query(query)\n",
"print(answer.response)\n",
"find = re.findall(r\"'page_label': '[^']*'\", str(answer.metadata))\n",
"page_num = eval(\"{\" + find[0] + \"}\")[\"page_label\"]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a73a77d5-ed0b-4b7b-8323-4344b901e2c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: total: 31.2 ms\n",
"Wall time: 1.11 s\n"
]
},
{
"data": {
"text/plain": [
"'No, Jacob Austin is not mentioned as an author of this paper.'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"ask(query, sections[\"authors\"][0])"
]
},
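{
"cell_type": "markdown",
"id": "c9d0e1f2",
"metadata": {},
"source": [
"The two comparisons above show the pattern: plain vector retrieval missed the license and wrongly confirmed a non-author, while answering against the pre-extracted section got both right. A small sketch tying the halves together (the name `brag_ask` and the fallback policy are my own, not from the original):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0e1f2a3",
"metadata": {},
"outputs": [],
"source": [
"def brag_ask(query, section_name=None):\n",
"    # Sketch: answer from a pre-extracted section when one is available,\n",
"    # otherwise fall back to plain vector retrieval.\n",
"    section_text = sections.get(section_name, (\"\", \"\"))[0] if section_name else \"\"\n",
"    if section_text:\n",
"        return ask(query, section_text)\n",
"    return query_engine.query(query).response\n",
"\n",
"brag_ask(\"Is Jacob Austin an author of this paper?\", \"authors\")"
]
},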
{
"cell_type": "code",
"execution_count": null,
"id": "753e59a6-6f79-44cb-bd38-3ccecc5a20d1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}