import codecs
import functools
import heapq
import io
import mimetypes
import operator
import random
import re
import subprocess
import tempfile
import threading
import typing
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import numpy as np
import openai
import pandas as pd
import pdftotext
import requests
import tiktoken
from furl import furl
from google.cloud import translate_v2
from googleapiclient import discovery
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
from pydantic import BaseModel
# stolen from spacy https://spacy.io/api/sentencizer
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።',
                       '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄', '᥅', '᪨', '᪩', '᪪', '᪫',
                       '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿', '‼', '‽', '⁇', '⁈', '⁉',
                       '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶', '꡷', '꣎', '꣏', '꤯', '꧈',
                       '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒', '﹖', '﹗', '!', '.', '?',
                       '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀', '𑃁', '𑅁', '𑅂', '𑅃', '𑇅',
                       '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼', '𑊩', '𑑋', '𑑌', '𑗂',
                       '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐', '𑗑', '𑗒', '𑗓',
                       '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂', '𑩃', '𑪛',
                       '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈', '。', '。']  # fmt: skip
pad = r"\s*"
whitespace = pad + r"(\s)" + pad
line_break = r"([\r\n\f\v])"
new_line = pad + line_break + pad
new_para = pad + line_break + pad + line_break + pad
puncts = "".join(map(re.escape, default_punct_chars))
sentence_end = pad + r"([" + puncts + r"])"
default_separators = (
    re.compile(sentence_end + new_para),
    re.compile(new_para),
    re.compile(sentence_end + new_line),
    re.compile(sentence_end + whitespace),
    re.compile(new_line),
    re.compile(whitespace),
)
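
# How _split() uses these patterns: they are tried in order, so text is broken first at a
# sentence end followed by a paragraph break, then at paragraph breaks, sentence ends at a
# line break, sentence ends followed by any whitespace, line breaks, and finally any whitespace.
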
# from https://useragentstring.com/
FAKE_USER_AGENTS = [
    # chrome
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.79 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.0.0 Safari/537.36",
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/103.0.5060.53 Safari/537.36",
    # edge
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19577",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/64.0.3282.140 Safari/537.36 Edge/18.17720",
    "Mozilla/5.0 (Windows NT 10.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.8810.3391 Safari/537.36 Edge/18.14383",
]
T = typing.TypeVar("T")
R = typing.TypeVar("R")
L = typing.Callable[[str], int]
whitespace_re = re.compile(r"\s+")
FILENAME_WHITELIST = re.compile(r"[ a-zA-Z0-9\-_.]")

class DocMetadata(typing.NamedTuple):
    name: str
    etag: str | None
    mime_type: str | None

# thread-local cache for the gpt2 tokenizer (shared by default_length_function and calc_gpt_tokens)
threadlocal = threading.local()


def default_length_function(text: str) -> int:
    try:
        enc = threadlocal.gpt2enc
    except AttributeError:
        enc = tiktoken.get_encoding("gpt2")
        threadlocal.gpt2enc = enc
    return len(enc.encode(text))

class Document:
    _length: int | None = None

    def __init__(
        self,
        text: str,
        span: tuple[int, int],
        length_function: L = default_length_function,
    ):
        self.text = text
        self.span = span
        self.start = self.span[0]
        self.end = self.span[1]
        self.length_function = length_function

    def __len__(self):
        if self._length is None:
            self._length = self.length_function(self.text)
        return self._length

    def __add__(self, other):
        return Document(
            text=self.text + other.text,
            span=(self.start, other.end),
            length_function=self.length_function,
        )

    def __repr__(self):
        return f"{self.__class__.__qualname__}(span={self.span!r}, text={self.text!r})"
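
# Illustrative example (hypothetical values): adding two Documents concatenates their text
# and merges their spans, e.g.
#   Document("Hello ", (0, 0), len) + Document("world", (1, 1), len)
#   # -> Document(span=(0, 1), text='Hello world')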

class DocSearchRequest(BaseModel):
    search_query: str
    documents: list[str] | None
    max_references: int | None
    max_context_words: int | None
    scroll_jump: int | None
    selected_asr_model: str | None
    google_translate_target: str | None

class SearchReference(typing.TypedDict):
    url: str
    title: str
    snippet: str
    score: float

def get_top_k_references(
    request: DocSearchRequest,
) -> list[SearchReference]:
    """
    Get the top k document snippets that best match the search query.

    Args:
        request: the document search request

    Returns:
        the top k references, sorted by similarity score
    """
    query_embeds = openai_embedding_create([request.search_query])[0]
    input_docs = request.documents or []
    embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel(
        lambda f_url: doc_url_to_embeds(
            f_url=f_url,
            max_context_words=request.max_context_words,
            scroll_jump=request.scroll_jump,
            selected_asr_model=request.selected_asr_model,
            google_translate_target=request.google_translate_target,
        ),
        input_docs,
    )
    # keep all matches above the cutoff, scored by cosine similarity
    cutoff = 0.7
    candidates = [
        {**ref, "score": score}
        for ref, doc_embeds in embeds
        if (score := query_embeds.dot(doc_embeds)) >= cutoff
    ]
    # get top_k best matches
    references = heapq.nlargest(
        request.max_references, candidates, key=lambda match: match["score"]
    )
    # merge duplicate references that point to the same url
    uniques = {}
    for ref in references:
        key = ref["url"]
        try:
            existing = uniques[key]
        except KeyError:
            uniques[key] = ref
        else:
            existing["snippet"] += "\n\n...\n\n" + ref["snippet"]
            existing["score"] = (existing["score"] + ref["score"]) / 2
    return list(uniques.values())
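
# Minimal usage sketch (hypothetical values; assumes OpenAI and Google Cloud credentials
# are configured, and the document URL is a placeholder):
#
#   request = DocSearchRequest(
#       search_query="what is the refund policy?",
#       documents=["https://example.com/handbook.pdf"],
#       max_references=3,
#       max_context_words=200,
#       scroll_jump=5,
#       selected_asr_model=None,
#       google_translate_target=None,
#   )
#   refs = get_top_k_references(request)
#   print(references_as_prompt(refs))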

def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:
    """
    Convert a list of references to a prompt containing the formatted search results.

    Args:
        references: list of references
        sep: separator between references in the prompt

    Returns:
        prompt string
    """
    return sep.join(
        f'''\
Search Result: [{idx + 1}]
Title: """{remove_quotes(ref["title"])}"""
Snippet: """
{remove_quotes(ref["snippet"])}
"""\
'''
        for idx, ref in enumerate(references)
    )
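
# For a single (hypothetical) reference such as
# {"title": "Refund Policy", "snippet": "Refunds are issued within 30 days.", "url": "...", "score": 0.9},
# the generated prompt block looks like:
#
#   Search Result: [1]
#   Title: """Refund Policy"""
#   Snippet: """
#   Refunds are issued within 30 days.
#   """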

def doc_url_to_embeds(
    *,
    f_url: str,
    max_context_words: int,
    scroll_jump: int,
    selected_asr_model: str | None = None,
    google_translate_target: str | None = None,
) -> list[tuple[SearchReference, np.ndarray]]:
    """
    Get document embeddings for a given document url.

    Args:
        f_url: document url
        max_context_words: max number of words to include in each chunk
        scroll_jump: number of words to scroll by
        google_translate_target: target language for google translate
        selected_asr_model: selected ASR model (used for audio files)

    Returns:
        list of (SearchReference, embeddings vector) tuples
    """
    doc_meta = doc_url_to_metadata(f_url)
    return get_embeds_for_doc(
        f_url=f_url,
        doc_meta=doc_meta,
        max_context_words=max_context_words,
        scroll_jump=scroll_jump,
        selected_asr_model=selected_asr_model,
        google_translate_target=google_translate_target,
    )

def remove_quotes(snippet: str) -> str:
    return re.sub(r"[\"\']+", r'"', snippet).strip()

def doc_url_to_metadata(f_url: str) -> DocMetadata:
    """
    Fetches the google drive metadata for a document url.

    Args:
        f_url: document url

    Returns:
        document metadata
    """
    f = furl(f_url.strip("/"))
    if is_gdrive_url(f):
        # extract filename from google drive metadata
        try:
            meta = gdrive_metadata(url_to_gdrive_file_id(f))
        except HttpError as e:
            if e.status_code == 404:
                raise FileNotFoundError(
                    f"Could not download the google doc at {f_url}. "
                    f"Please make sure the document is publicly viewable."
                ) from e
            else:
                raise
        name = meta["name"]
        etag = meta.get("md5Checksum") or meta.get("modifiedTime")
        mime_type = meta["mimeType"]
    # elif is_user_uploaded_url(str(f)):
    #     # extract filename from url
    #     name = f.path.segments[-1]
    #     etag = None
    #     mime_type = None
    else:
        # extract filename from url
        name = f"{f.host}{f.path}"
        etag = None
        mime_type = None
    return DocMetadata(name, etag, mime_type)

def get_embeds_for_doc(
    *,
    f_url: str,
    doc_meta: DocMetadata,
    max_context_words: int,
    scroll_jump: int,
    google_translate_target: str | None = None,
    selected_asr_model: str | None = None,
) -> list[tuple[SearchReference, np.ndarray]]:
    """
    Get document embeddings for a given document url.

    Args:
        f_url: document url
        doc_meta: document metadata
        max_context_words: max number of words to include in each chunk
        scroll_jump: number of words to scroll by
        google_translate_target: target language for google translate
        selected_asr_model: selected ASR model (used for audio files)

    Returns:
        list of (metadata, embeddings) tuples
    """
    pages = doc_url_to_text_pages(
        f_url=f_url,
        doc_meta=doc_meta,
        selected_asr_model=selected_asr_model,
        google_translate_target=google_translate_target,
    )
    chunk_size = int(max_context_words * 2)
    chunk_overlap = int(max_context_words * 2 / scroll_jump)
    metas: list[SearchReference]
    # split the text into chunks
    if isinstance(pages, pd.DataFrame):
        metas = [
            {
                "title": doc_meta.name,
                "url": f_url,
                **row,  # preserve extra csv columns
                "score": -1,
                "snippet": doc.text,
            }
            for idx, row in pages.iterrows()
            for doc in text_splitter(
                row["snippet"], chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
        ]
    else:
        metas = [
            {
                "title": doc_meta.name
                + (f" - Page {doc.end + 1}" if len(pages) > 1 else ""),
                "url": furl(f_url)
                .set(fragment_args={"page": doc.end + 1} if len(pages) > 1 else {})
                .url,
                "snippet": doc.text,
                "score": -1,
            }
            for doc in text_splitter(
                pages, chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
        ]
    # get doc embeds in batches
    embeds = []
    batch_size = 100
    texts = [m["title"] + " | " + m["snippet"] for m in metas]
    for i in range(0, len(texts), batch_size):
        # progress = int(i / len(texts) * 100)
        # print(f"Getting document embeddings ({progress}%)...")
        batch = texts[i : i + batch_size]
        embeds.extend(openai_embedding_create(batch))
    return list(zip(metas, embeds))

def doc_url_to_text_pages(
    *,
    f_url: str,
    doc_meta: DocMetadata,
    google_translate_target: str | None,
    selected_asr_model: str | None,
) -> list[str] | pd.DataFrame:
    """
    Download document from url and convert to text pages.

    Args:
        f_url: url of document
        doc_meta: document metadata
        google_translate_target: target language for google translate
        selected_asr_model: selected ASR model (used for audio files)

    Returns:
        list of text pages (or a DataFrame for spreadsheet files)
    """
    f = furl(f_url)
    f_name = doc_meta.name
    if is_gdrive_url(f):
        # download from google drive
        f_bytes, ext = gdrive_download(f, doc_meta.mime_type)
    else:
        # download from url
        try:
            r = requests.get(
                f_url,
                headers={"User-Agent": random.choice(FAKE_USER_AGENTS)},
                timeout=10,
            )
            r.raise_for_status()
        except requests.RequestException as e:
            print(f"ignore error while downloading {f_url}: {e}")
            return []
        f_bytes = r.content
        # if it's a known encoding, standardize to utf-8
        if r.encoding:
            try:
                codec = codecs.lookup(r.encoding)
            except LookupError:
                pass
            else:
                f_bytes = codec.decode(f_bytes)[0].encode()
        ext = guess_ext_from_response(r)
    # convert document to text pages
    match ext:
        case ".pdf":
            pages = pdf_to_text_pages(io.BytesIO(f_bytes))
        case ".docx" | ".md" | ".html" | ".rtf" | ".epub" | ".odt":
            pages = [pandoc_to_text(f_name + ext, f_bytes)]
        case ".txt":
            pages = [f_bytes.decode()]
        # case ".wav" | ".ogg" | ".mp3" | ".aac":
        #     if not selected_asr_model:
        #         raise ValueError(
        #             "For transcribing audio/video, please choose an ASR model from the settings!"
        #         )
        #     if is_gdrive_url(f):
        #         f_url = upload_file_from_bytes(
        #             f_name, f_bytes, content_type=doc_meta.mime_type
        #         )
        #     pages = [run_asr(f_url, selected_model=selected_asr_model, language="en")]
        case ".csv" | ".xlsx" | ".tsv" | ".ods":
            df = pd.read_csv(io.BytesIO(f_bytes), dtype=str).dropna()
            assert (
                "snippet" in df.columns
            ), f'uploaded spreadsheet must contain a "snippet" column - {f_name!r}'
            pages = df
        case _:
            raise ValueError(f"Unsupported document format {ext!r} ({f_name})")
    # optionally, translate text
    if google_translate_target:
        pages = run_google_translate(pages, google_translate_target)
    return pages

def pdf_to_text_pages(f: typing.BinaryIO) -> list[str]:
    return list(pdftotext.PDF(f))

def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str:
    """
    Convert document to text using pandoc.

    Args:
        f_name: filename of document
        f_bytes: document bytes
        to: pandoc output format (default: plain)

    Returns:
        extracted text content of document
    """
    with (
        tempfile.NamedTemporaryFile("wb", suffix="." + safe_filename(f_name)) as infile,
        tempfile.NamedTemporaryFile("r") as outfile,
    ):
        infile.write(f_bytes)
        infile.flush()  # flush the buffer so pandoc sees the full file contents
        args = [
            "pandoc",
            "--standalone",
            infile.name,
            "--to",
            to,
            "--output",
            outfile.name,
        ]
        print("\t$ " + " ".join(args))
        subprocess.check_call(args)
        return outfile.read()

def text_splitter(
    docs: typing.Iterable[str | Document],
    *,
    chunk_size: int,
    chunk_overlap: int = 0,
    separators: list[re.Pattern] = default_separators,
    length_function: L = default_length_function,
) -> list[Document]:
    if not docs:
        return []
    if isinstance(docs, str):
        docs = [docs]
    if isinstance(docs[0], str):
        docs = [Document(d, (idx, idx), length_function) for idx, d in enumerate(docs)]
    splits = _split(docs, chunk_size, separators)
    docs = list(_join(splits, chunk_size, chunk_overlap))
    return docs
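
# Minimal usage sketch (hypothetical values); a plain character-count length function
# makes the behaviour easy to follow:
#
#   chunks = text_splitter(
#       "First sentence. Second sentence. Third sentence.",
#       chunk_size=20,
#       chunk_overlap=0,
#       length_function=len,
#   )
#   # -> roughly one Document per sentence, split at the sentence-end separators above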

def _split(
    docs: list[Document],
    chunk_size: int,
    separators: list[re.Pattern],
) -> typing.Iterable[Document]:
    if not separators:
        raise ValueError("No separators left, cannot split further")
    for doc in docs:
        # skip empty docs
        if not doc.text.strip():
            continue
        # if the doc is small enough, no need to split
        if len(doc) <= chunk_size:
            yield doc
            continue
        for text in re_split(separators[0], doc.text):
            # skip empty fragments
            if not text.strip():
                continue
            frag = Document(text, doc.span, doc.length_function)
            # if the fragment is small enough, no need for further splitting
            if len(frag) <= chunk_size:
                yield frag
            else:
                yield from _split([frag], chunk_size, separators[1:])

def re_split(pat: re.Pattern, text: str):
    """Similar to re.split, but preserves the matched groups after splitting"""
    last_match_end = 0
    for match in pat.finditer(text):
        end_char = "".join(match.groups())
        frag = text[last_match_end : match.start()] + end_char
        if frag:
            yield frag
        last_match_end = match.end()
    yield text[last_match_end:]
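
# Quick illustration (hypothetical pattern): unlike re.split, the captured separator
# stays attached to the preceding fragment:
#   list(re_split(re.compile(r"([.!?])\s*"), "One. Two! Three"))
#   # -> ['One.', 'Two!', 'Three']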

def _join(
    docs: typing.Iterable[Document],
    chunk_size: int,
    chunk_overlap: int,
) -> typing.Iterator[Document]:
    window = deque()
    window_len = 0
    for doc in docs:
        # grow window until largest possible chunk
        if window_len + len(doc) <= chunk_size:
            window.append(doc)
            window_len += len(doc)
        else:
            # return the window until now
            if window:
                yield _merge(window)
            # reset window
            prev_window = window
            window = deque([doc])
            window_len = len(doc)
            # add overlap from previous window
            overlap_len = 0
            for chunk in reversed(prev_window):
                if (
                    # check if overlap is too large
                    overlap_len + len(chunk) > chunk_overlap
                    # check if window is too large
                    or window_len + len(chunk) > chunk_size
                ):
                    break
                window.appendleft(chunk)
                overlap_len += len(chunk)
                window_len += len(chunk)
    # return the leftover
    if window:
        yield _merge(window)

def _merge(docs: typing.Iterable[Document]) -> Document:
    ret = functools.reduce(operator.add, docs)
    return Document(
        text=ret.text.strip(),  # remove whitespace after merge
        span=ret.span,
        length_function=ret.length_function,
    )

def guess_ext_from_response(response: requests.Response) -> str:
    content_type = response.headers.get("Content-Type", "application/octet-stream")
    mimetype = content_type.split(";")[0]
    return mimetypes.guess_extension(mimetype) or ""

def safe_filename(filename: str) -> str:
    matches = FILENAME_WHITELIST.finditer(filename)
    filename = "".join(match.group(0) for match in matches)
    p = Path(filename)
    out = truncate_filename(p.stem) + p.suffix
    return out

def truncate_filename(text: str, maxlen: int = 100, sep: str = "...") -> str:
    if len(text) <= maxlen:
        return text
    assert len(sep) <= maxlen
    mid = (maxlen - len(sep)) // 2
    return text[:mid] + sep + text[-mid:]
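
# Illustrative examples (hypothetical inputs):
#   safe_filename("my report (final)?.pdf")  # -> "my report final.pdf"
#   truncate_filename("a" * 150)             # -> 48 chars + "..." + 48 chars (99 total)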

def is_gdrive_url(f: furl) -> bool:
    return f.host in ["drive.google.com", "docs.google.com"]

def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]:
    # get drive file id
    file_id = url_to_gdrive_file_id(f)
    # get metadata
    service = discovery.build("drive", "v3")
    # get files in drive directly
    if f.host == "drive.google.com":
        request = service.files().get_media(fileId=file_id)
        ext = mimetypes.guess_extension(mime_type)
    # export google docs to appropriate type
    else:
        mime_type, ext = docs_export_mimetype(f)
        request = service.files().export_media(fileId=file_id, mimeType=mime_type)
    # download
    file = io.BytesIO()
    downloader = MediaIoBaseDownload(file, request)
    done = False
    while done is False:
        _, done = downloader.next_chunk()
        # print(f"Download {int(status.progress() * 100)}%")
    f_bytes = file.getvalue()
    return f_bytes, ext

def url_to_gdrive_file_id(f: furl) -> str:
    # extract google drive file ID
    try:
        # https://drive.google.com/u/0/uc?id=FILE_ID&...
        file_id = f.query.params["id"]
    except KeyError:
        # https://drive.google.com/file/d/FILE_ID/...
        # https://docs.google.com/document/d/FILE_ID/...
        try:
            file_id = f.path.segments[f.path.segments.index("d") + 1]
        except (IndexError, ValueError):
            raise ValueError(f"Bad google drive link: {str(f)!r}")
    return file_id
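
# e.g. all of these resolve to "FILE_ID" (placeholder, as in the comments above):
#   url_to_gdrive_file_id(furl("https://drive.google.com/file/d/FILE_ID/view"))
#   url_to_gdrive_file_id(furl("https://docs.google.com/document/d/FILE_ID/edit"))
#   url_to_gdrive_file_id(furl("https://drive.google.com/u/0/uc?id=FILE_ID"))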

def docs_export_mimetype(f: furl) -> tuple[str, str]:
    """
    Return the mimetype to export google docs as - https://developers.google.com/drive/api/guides/ref-export-formats

    Args:
        f (furl): google docs link

    Returns:
        tuple[str, str]: (mime_type, extension)
    """
    if "document" in f.path.segments:
        mime_type = "text/plain"
        ext = ".txt"
    elif "spreadsheets" in f.path.segments:
        mime_type = "text/csv"
        ext = ".csv"
    elif "presentation" in f.path.segments:
        mime_type = "application/pdf"
        ext = ".pdf"
    elif "drawings" in f.path.segments:
        mime_type = "application/pdf"
        ext = ".pdf"
    else:
        raise ValueError(f"Not sure how to export google docs url: {str(f)!r}")
    return mime_type, ext

def flatmap_parallel(
    fn: typing.Callable[[T], list[R]], it: typing.Sequence[T]
) -> list[R]:
    return flatten(map_parallel(fn, it))


def flatten(l1: typing.Iterable[typing.Iterable[T]]) -> list[T]:
    return [it for l2 in l1 for it in l2]
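
# Illustrative (hypothetical inputs):
#   flatten([[1, 2], [3]])  # -> [1, 2, 3]
#   flatmap_parallel(fn, urls) maps fn over urls in a thread pool and flattens the
#   per-url result lists into a single list.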

def map_parallel(
    fn: typing.Callable[[T], R],
    *iterables: typing.Sequence[T],
    max_workers: int | None = None,
) -> list[R]:
    assert iterables, "map_parallel() requires at least one iterable"
    # use at least one worker, even if all iterables are empty
    max_workers = max_workers or max(map(len, iterables)) or 1
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fn, *iterables))

def gdrive_metadata(file_id: str) -> dict:
    service = discovery.build("drive", "v3")
    metadata = (
        service.files()
        .get(
            fileId=file_id,
            fields="name,md5Checksum,modifiedTime,mimeType",
        )
        .execute()
    )
    return metadata

def run_google_translate(texts: list[str], google_translate_target: str) -> list[str]:
    """
    Translate text using the Google Translate API.

    Args:
        texts (list[str]): Text to be translated.
        google_translate_target (str): Language code to translate to.

    Returns:
        list[str]: Translated text.
    """
    translate_client = translate_v2.Client()
    result = translate_client.translate(
        texts, target_language=google_translate_target, format_="text"
    )
    return [r["translatedText"] for r in result]

def openai_embedding_create(texts: list[str]) -> list[np.ndarray]:
    # replace newlines, which can negatively affect performance
    texts = [whitespace_re.sub(" ", text) for text in texts]
    res = openai.Embedding.create(model="text-embedding-ada-002", input=texts)
    ret = np.array([data["embedding"] for data in res["data"]])
    # see - https://community.openai.com/t/text-embedding-ada-002-embeddings-sometime-return-nan/279664/5
    if np.isnan(ret).any():
        raise RuntimeError("NaNs detected in embedding")
        # raise openai.error.APIError("NaNs detected in embedding")  # this lets us retry
    expected = (len(texts), 1536)
    if ret.shape != expected:
        raise RuntimeError(
            f"Unexpected shape for embedding: {ret.shape} (expected {expected})"
        )
    return ret

def calc_gpt_tokens(
    text: str | list[str] | dict | list[dict],
    *,
    sep: str = "",
) -> int:
    try:
        enc = threadlocal.gpt2enc
    except AttributeError:
        enc = tiktoken.get_encoding("gpt2")
        threadlocal.gpt2enc = enc
    if isinstance(text, (str, dict)):
        messages = [text]
    else:
        messages = text
    combined = sep.join(
        content
        for entry in messages
        if (
            content := entry.get("content", "").strip()
            if isinstance(entry, dict)
            else str(entry)
        )
    )
    return len(enc.encode(combined))
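
# Illustrative (hypothetical inputs): with the gpt2 BPE, "hello" and " world" are one
# token each, so:
#   calc_gpt_tokens("hello world")                                 # -> 2
#   calc_gpt_tokens([{"role": "user", "content": "hello world"}])  # -> 2 (dicts use "content")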

def convo_window_clipper(
    window: list[dict],
    max_tokens,
    *,
    sep: str = "",
    step=2,
):
    for i in range(len(window) - 2, -1, -step):
        if calc_gpt_tokens(window[i:], sep=sep) > max_tokens:
            return i + step
    return 0
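
# Usage sketch (hypothetical message list and budget):
#   clip_idx = convo_window_clipper(history_window, max_tokens=1000)
#   history_window = history_window[clip_idx:]  # keep only the most recent messages that fit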

def build_prompt(
    *,
    search_query: str,
    references: list[dict],
    task_instructions: str,
    system_prompt: str,
    history_window: list[dict],
) -> list[dict]:
    user_prompt = {"role": "user", "content": search_query}
    if references:
        user_prompt["content"] = (
            references_as_prompt(references)
            + f"\n**********\n{task_instructions.strip()}\n**********\n"
            + user_prompt["content"]
        )
    # truncate the history to fit the model's max tokens
    # (model_max_tokens and max_output_tokens are not defined in this gist; presumably
    #  they come from the surrounding app's model settings)
    safety_buffer = 100
    max_history_tokens = (
        model_max_tokens
        - calc_gpt_tokens([system_prompt, user_prompt])
        - max_output_tokens
        - safety_buffer
    )
    clip_idx = convo_window_clipper(history_window, max_history_tokens)
    history_window = history_window[clip_idx:]
    prompt_messages = [system_prompt, *history_window, user_prompt]
    return prompt_messages