@samarv
Created May 5, 2023 19:21
from flask import Flask, jsonify, request
import requests
import PyPDF2
import tempfile
import pickle
import retrying
from langchain.llms import OpenAI
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate #Imports
import tiktoken
import hashlib
import os
import urllib.parse
app = Flask(__name__) #Initiate the app
# In a real application, you would want to use a more secure method for generating and storing tokens
TOKEN = ""
openai_api_key = "sk-"
k = 3
def remove_substring(url):
    # Strip the file extension (e.g. ".pdf") and the scheme prefix from a URL.
    try:
        url = url[:-4]
        nameoffile = url.replace("https://", "")
        return nameoffile
    except Exception:
        return url
@retrying.retry(retry_on_exception=lambda x: isinstance(x, requests.exceptions.RequestException), stop_max_attempt_number=5, wait_exponential_multiplier=1000, wait_exponential_max=10000)
def generate_hash(url):
    try:
        # Encode URL as bytes
        url_bytes = url.encode('utf-8')
        # Generate hash object using SHA-256 algorithm
        hash_object = hashlib.sha256(url_bytes)
        # Convert hash object to hexadecimal string
        hex_dig = hash_object.hexdigest()
        # Split hash string into 4 parts of equal length
        split_len = len(hex_dig) // 4
        hash_parts = [
            hex_dig[i:i + split_len] for i in range(0, len(hex_dig), split_len)
        ]
        # Join hash parts with dashes and return as a string
        return '-'.join(hash_parts)
    except Exception as e:
        print(f"Error generating hash: {e}")
@retrying.retry(retry_on_exception=lambda x: isinstance(x, requests.exceptions.RequestException), stop_max_attempt_number=5, wait_exponential_multiplier=1000, wait_exponential_max=10000)
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
# Decorator function to check if the request contains a valid token
def require_token(f):
    def wrapper(*args, **kwargs):
        token = request.headers.get('Authorization')
        if not token or token != f'Token {TOKEN}':
            return jsonify({"error": "Unauthorized"}), 401
        return f(*args, **kwargs)
    wrapper.__name__ = f.__name__
    return wrapper
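# Clients are therefore expected to send the header "Authorization: Token <TOKEN>";
# requests without it (or with a different value) receive a 401 response.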
@app.route('/', methods=["GET"]) #Home page
@require_token
@retrying.retry(retry_on_exception=lambda x: isinstance(x, requests.exceptions.RequestException), stop_max_attempt_number=5, wait_exponential_multiplier=1000, wait_exponential_max=10000)
def home():
    return "<h1>Home for 'insert api name here' api.</h1>"
template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
ALWAYS return a list of "SOURCES" part in your answer.
QUESTION: {question}
=========
{context}
=========
FINAL ANSWER:"""
PROMPT = PromptTemplate(template=template,
                        input_variables=["context", "question"])
@retrying.retry(retry_on_exception=lambda x: isinstance(
    x, requests.exceptions.RequestException),
                stop_max_attempt_number=5,
                wait_exponential_multiplier=1000,
                wait_exponential_max=10000)
def get_pdf_text(pdf_url):
    # Download the PDF into a temporary file and yield one Document per page.
    with tempfile.NamedTemporaryFile(suffix=".pdf") as pdf_file:
        response = requests.get(pdf_url)
        pdf_file.write(response.content)
        pdf_file.seek(0)
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page_number in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_number]
            yield Document(page_content=page.extract_text(),
                           metadata={"source": page_number + 1})
def get_file_text(file_url):
    # Download the file and yield Documents: one per page for PDFs, one for plain text files.
    _, ext = os.path.splitext(file_url)
    if ext.lower() == ".pdf":
        with tempfile.NamedTemporaryFile(suffix=".pdf") as pdf_file:
            response = requests.get(file_url)
            pdf_file.write(response.content)
            pdf_file.seek(0)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            for page_number in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_number]
                yield Document(page_content=page.extract_text(),
                               metadata={"source": file_url, "page_number": page_number + 1})
    elif ext.lower() == ".txt":
        with tempfile.NamedTemporaryFile(suffix=".txt") as txt_file:
            response = requests.get(file_url)
            txt_file.write(response.content)
            txt_file.seek(0)
            # The temporary file yields bytes, so decode before building the Document.
            text = txt_file.read().decode("utf-8", errors="ignore")
            yield Document(page_content=text,
                           metadata={"source": file_url})
    else:
        raise ValueError("Unsupported file type")
@app.route('/v1/api/pdf_to_answer', methods=["POST"])  # PDF question-answering endpoint
@require_token
@retrying.retry(retry_on_exception=lambda x: isinstance(
    x, requests.exceptions.RequestException),
                stop_max_attempt_number=5,
                wait_exponential_multiplier=1000,
                wait_exponential_max=10000)
def pdf_to_answer():
    json_data = request.get_json()
    question = json_data.get('question')
    pdf_url = json_data.get('pdf_url')
    file_name = generate_hash(pdf_url)
    k = 4
    if not all([question, pdf_url, file_name]):
        return jsonify({"error": "Missing required parameters"}), 400
    search_index = None
    chain = load_qa_with_sources_chain(OpenAI(temperature=0,
                                              openai_api_key=openai_api_key),
                                       chain_type="stuff",
                                       prompt=PROMPT)
    try:
        # Reuse a previously built index for this URL if one was pickled to disk.
        with open(file_name, "rb") as f:
            search_index = pickle.load(f)
    except FileNotFoundError:
        # Otherwise download the file, split it into chunks and build a FAISS index.
        source_docs = list(get_file_text(pdf_url))
        source_chunks = []
        splitter = CharacterTextSplitter(separator=" ",
                                         chunk_size=1024,
                                         chunk_overlap=0)
        for source in source_docs:
            for chunk in splitter.split_text(source.page_content):
                source_chunks.append(
                    Document(page_content=chunk, metadata=source.metadata))
        search_index = FAISS.from_documents(
            source_chunks, OpenAIEmbeddings(openai_api_key=openai_api_key))
        with open(file_name, "wb") as f:
            pickle.dump(search_index, f)
    input_documents = search_index.similarity_search(question, k=k)
    questiontokensCount = num_tokens_from_string(str(input_documents) + question, "gpt2")
    # Shrink k until the retrieved context plus the question fits under the token budget.
    while questiontokensCount > 3500:
        k = k - 1
        input_documents = search_index.similarity_search(question, k=k)
        questiontokensCount = num_tokens_from_string(str(input_documents) + question, "gpt2")
    stuffchain = chain(
        {
            "input_documents": input_documents,
            "question": question,
        },
        return_only_outputs=True,
    )["output_text"]
    return {"stuffchain": stuffchain, "source": str(pdf_url)}
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)  # Run the app
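# A minimal usage sketch (not part of the original gist), assuming the server runs
# locally on port 8080, TOKEN is set to "secret", and the PDF URL below is a
# hypothetical placeholder:
#
#   curl -X POST http://localhost:8080/v1/api/pdf_to_answer \
#     -H "Authorization: Token secret" \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What is this document about?", "pdf_url": "https://example.com/sample.pdf"}'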