from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_text_splitters import RecursiveCharacterTextSplitter

system = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)

# Input data
loader = DirectoryLoader("txtai", glob="**/*.pdf")
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200, add_start_index=True
)

# Embedding model
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cuda"}
)

# Create vector store
documents = splitter.split_documents(loader.load())
store = FAISS.from_documents(documents, embeddings)

# Define RAG chain
retriever = store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3}
)
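
# Optional sanity check: the retriever is a standalone runnable, so it can be
# tested before wiring up the chain. A sketch (the query string is just an example):
# docs = retriever.invoke("image captioning")
# print([d.metadata for d in docs])
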
llm = HuggingFacePipeline.from_model_id(
    task="text-generation",
    model_id="TheBloke/Mistral-7B-OpenOrca-AWQ",
    device=0,
    pipeline_kwargs={"max_new_tokens": 2048}
)

# Input prompt
prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "{input}"),
])

# Create RAG chain
qa = create_stuff_documents_chain(llm, prompt)
rag = create_retrieval_chain(retriever, qa)

# Run RAG
response = rag.invoke({"input": "What model does txtai recommend for image captioning?"})
print(response["answer"].split("\n")[-1])
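
# The retrieved documents are also available via response["context"].
# Optional: the FAISS store can be persisted so indexing doesn't rerun on
# every start. A sketch assuming a local "index" directory (path name is an assumption):
# store.save_local("index")
# store = FAISS.load_local("index", embeddings, allow_dangerous_deserialization=True)
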
###################################
# txtai equivalent of the LangChain pipeline above
from glob import glob

from txtai import Embeddings, LLM
from txtai.pipeline import Textractor

def prompt(question, context):
    system = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        f"{context}"
    )

    return [
        {"role": "system", "content": system},
        {"role": "user", "content": question}
    ]

def stream(path):
    # Stream paragraphs from PDFs
    for paragraphs in textractor(glob(f"{path}/*.pdf")):
        yield from paragraphs

def rag(question):
    # Build the context from search results, then run RAG
    context = "\n".join(x["text"] for x in embeddings.search(question))
    return llm(prompt(question, context), maxlength=2048)
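
# Note: with content=True, embeddings.search returns dicts with "id", "text"
# and "score" fields, which is what the rag function above relies on.
# Illustrative result shape: [{"id": "0", "text": "...", "score": 0.62}, ...]
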
# Text extractor formats - https://tika.apache.org/2.9.2/formats.html
textractor = Textractor(paragraphs=True)

# Create vector store
embeddings = Embeddings(content=True)
embeddings.index(stream("txtai"))

# LLM - supports Hugging Face models, llama.cpp GGUF and APIs (OpenAI, Ollama)
llm = LLM("TheBloke/Mistral-7B-OpenOrca-AWQ")

# Run RAG
print(rag("What model does txtai recommend for image captioning?"))
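
# Optional: the txtai index can be saved and reloaded between runs.
# A sketch assuming a local "index" directory (the path is an assumption):
# embeddings.save("index")
# embeddings.load("index")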