import codecs
import functools
import heapq
import io
import mimetypes
import operator
import random
import re
import subprocess
import tempfile
import threading
import typing
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np
import openai
import pandas as pd
import pdftotext
import requests
import tiktoken
from furl import furl
from google.cloud import translate_v2
from googleapiclient import discovery
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
from pydantic import BaseModel
# stolen from spacy https://spacy.io/api/sentencizer
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።',
                       '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄', '᥅', '᪨', '᪩', '᪪', '᪫',
                       '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿', '‼', '‽', '⁇', '⁈', '⁉',
                       '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶', '꡷', '꣎', '꣏', '꤯', '꧈',
                       '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒', '﹖', '﹗', '!', '.', '?',
                       '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀', '𑃁', '𑅁', '𑅂', '𑅃', '𑇅',
                       '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼', '𑊩', '𑑋', '𑑌', '𑗂',
                       '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐', '𑗑', '𑗒', '𑗓',
                       '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂', '𑩃', '𑪛',
                       '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈', '。', '。']  # fmt: skip
pad = r"\s*"
whitespace = pad + r"(\s)" + pad
line_break = r"([\r\n\f\v])"
new_line = pad + line_break + pad
new_para = pad + line_break + pad + line_break + pad
puncts = "".join(map(re.escape, default_punct_chars))
sentence_end = pad + r"([" + puncts + r"])"

default_separators = (
    re.compile(sentence_end + new_para),
    re.compile(new_para),
    re.compile(sentence_end + new_line),
    re.compile(sentence_end + whitespace),
    re.compile(new_line),
    re.compile(whitespace),
)
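# The separators above are ordered coarsest-first: a sentence boundary at a paragraph
# break, a bare paragraph break, a sentence boundary at a line break, a sentence
# boundary at any whitespace, a bare line break, and finally any whitespace.
# _split() below walks this list in order, only falling back to a finer separator when
# a fragment is still larger than the chunk size.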
# from https://useragentstring.com/
FAKE_USER_AGENTS = [
    # chrome
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.79 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.0.0 Safari/537.36",
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/103.0.5060.53 Safari/537.36",
    # edge
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19577",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/64.0.3282.140 Safari/537.36 Edge/18.17720",
    "Mozilla/5.0 (Windows NT 10.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.8810.3391 Safari/537.36 Edge/18.14383",
]
T = typing.TypeVar("T")
R = typing.TypeVar("R")
# type alias for a length function: maps text to its length in tokens
L = typing.Callable[[str], int]

whitespace_re = re.compile(r"\s+")
FILENAME_WHITELIST = re.compile(r"[ a-zA-Z0-9\-_.]")
class DocMetadata(typing.NamedTuple):
    name: str
    etag: str | None
    mime_type: str | None
# thread-local storage for the tiktoken encoder, so each thread builds it only once
threadlocal = threading.local()


def default_length_function(text: str) -> int:
    try:
        enc = threadlocal.gpt2enc
    except AttributeError:
        enc = tiktoken.get_encoding("gpt2")
        threadlocal.gpt2enc = enc
    return len(enc.encode(text))
class Document:
    """A text fragment, its (start, end) span in the source pages/rows, and a cached length."""

    _length: int | None = None

    def __init__(
        self,
        text: str,
        span: tuple[int, int],
        length_function: L = default_length_function,
    ):
        self.text = text
        self.span = span
        self.start = self.span[0]
        self.end = self.span[1]
        self.length_function = length_function

    def __len__(self):
        if self._length is None:
            self._length = self.length_function(self.text)
        return self._length

    def __add__(self, other):
        return Document(
            text=self.text + other.text,
            span=(self.start, other.end),
            length_function=self.length_function,
        )

    def __repr__(self):
        return f"{self.__class__.__qualname__}(span={self.span!r}, text={self.text!r})"
class DocSearchRequest(BaseModel):
    search_query: str
    documents: list[str] | None
    max_references: int | None
    max_context_words: int | None
    scroll_jump: int | None
    selected_asr_model: str | None
    google_translate_target: str | None


class SearchReference(typing.TypedDict):
    url: str
    title: str
    snippet: str
    score: float
def get_top_k_references(request: DocSearchRequest) -> list[SearchReference]:
    """
    Get the top k document snippets that best match the search query.

    Args:
        request: the document search request

    Returns:
        the top k references, sorted by score and de-duplicated by url
    """
    query_embeds = openai_embedding_create([request.search_query])[0]
    input_docs = request.documents or []
    embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel(
        lambda f_url: doc_url_to_embeds(
            f_url=f_url,
            max_context_words=request.max_context_words,
            scroll_jump=request.scroll_jump,
            selected_asr_model=request.selected_asr_model,
            google_translate_target=request.google_translate_target,
        ),
        input_docs,
    )
    # get all matches above cutoff based on cosine similarity
    # (ada-002 embeddings are unit length, so the dot product is the cosine similarity)
    cutoff = 0.7
    candidates = [
        {**ref, "score": score}
        for ref, doc_embeds in embeds
        if (score := query_embeds.dot(doc_embeds)) >= cutoff
    ]
    # get top_k best matches
    references = heapq.nlargest(
        request.max_references, candidates, key=lambda match: match["score"]
    )
    # merge duplicate references
    uniques = {}
    for ref in references:
        key = ref["url"]
        try:
            existing = uniques[key]
        except KeyError:
            uniques[key] = ref
        else:
            existing["snippet"] += "\n\n...\n\n" + ref["snippet"]
            existing["score"] = (existing["score"] + ref["score"]) / 2
    return list(uniques.values())
def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:
    """
    Convert a list of references to a prompt containing the formatted search results.

    Args:
        references: list of references
        sep: separator between references in the prompt

    Returns:
        prompt string
    """
    return sep.join(
        f'''\
Search Result: [{idx + 1}]
Title: """{remove_quotes(ref["title"])}"""
Snippet: """
{remove_quotes(ref["snippet"])}
"""\
'''
        for idx, ref in enumerate(references)
    )
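# Illustrative usage (not part of the original module): this requires OpenAI and
# Google Drive credentials to be configured, and the URL below is a placeholder.
#
#   refs = get_top_k_references(
#       DocSearchRequest(
#           search_query="What is the refund policy?",
#           documents=["https://example.com/handbook.pdf"],
#           max_references=3,
#           max_context_words=200,
#           scroll_jump=5,
#           selected_asr_model=None,
#           google_translate_target=None,
#       )
#   )
#   prompt = references_as_prompt(refs)  # paste into an LLM prompt alongside the query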
def doc_url_to_embeds(
    *,
    f_url: str,
    max_context_words: int,
    scroll_jump: int,
    selected_asr_model: str | None = None,
    google_translate_target: str | None = None,
) -> list[tuple[SearchReference, np.ndarray]]:
    """
    Get document embeddings for a given document url.

    Args:
        f_url: document url
        max_context_words: max number of words to include in each chunk
        scroll_jump: number of words to scroll by
        google_translate_target: target language for google translate
        selected_asr_model: selected ASR model (used for audio files)

    Returns:
        list of (SearchReference, embeddings vector) tuples
    """
    doc_meta = doc_url_to_metadata(f_url)
    return get_embeds_for_doc(
        f_url=f_url,
        doc_meta=doc_meta,
        max_context_words=max_context_words,
        scroll_jump=scroll_jump,
        selected_asr_model=selected_asr_model,
        google_translate_target=google_translate_target,
    )
def remove_quotes(snippet: str) -> str:
    # collapse runs of single/double quotes into one double quote, and strip whitespace
    return re.sub(r"[\"\']+", r'"', snippet).strip()
def doc_url_to_metadata(f_url: str) -> DocMetadata:
    """
    Fetches the google drive metadata for a document url.

    Args:
        f_url: document url

    Returns:
        document metadata
    """
    f = furl(f_url.strip("/"))
    if is_gdrive_url(f):
        # extract filename from google drive metadata
        try:
            meta = gdrive_metadata(url_to_gdrive_file_id(f))
        except HttpError as e:
            if e.status_code == 404:
                raise FileNotFoundError(
                    f"Could not download the google doc at {f_url}. "
                    f"Please make sure the document is public for viewing."
                ) from e
            else:
                raise
        name = meta["name"]
        etag = meta.get("md5Checksum") or meta.get("modifiedTime")
        mime_type = meta["mimeType"]
    # elif is_user_uploaded_url(str(f)):
    #     # extract filename from url
    #     name = f.path.segments[-1]
    #     etag = None
    #     mime_type = None
    else:
        # extract filename from url
        name = f"{f.host}{f.path}"
        etag = None
        mime_type = None
    return DocMetadata(name, etag, mime_type)
def get_embeds_for_doc(
    *,
    f_url: str,
    doc_meta: DocMetadata,
    max_context_words: int,
    scroll_jump: int,
    google_translate_target: str | None = None,
    selected_asr_model: str | None = None,
) -> list[tuple[SearchReference, np.ndarray]]:
    """
    Get document embeddings for a given document url.

    Args:
        f_url: document url
        doc_meta: document metadata
        max_context_words: max number of words to include in each chunk
        scroll_jump: number of words to scroll by
        google_translate_target: target language for google translate
        selected_asr_model: selected ASR model (used for audio files)

    Returns:
        list of (metadata, embeddings) tuples
    """
    pages = doc_url_to_text_pages(
        f_url=f_url,
        doc_meta=doc_meta,
        selected_asr_model=selected_asr_model,
        google_translate_target=google_translate_target,
    )
    chunk_size = int(max_context_words * 2)
    chunk_overlap = int(max_context_words * 2 / scroll_jump)
    metas: list[SearchReference]
    # split the text into chunks
    if isinstance(pages, pd.DataFrame):
        metas = [
            {
                "title": doc_meta.name,
                "url": f_url,
                **row,  # preserve extra csv columns
                "score": -1,
                "snippet": doc.text,
            }
            for idx, row in pages.iterrows()
            for doc in text_splitter(
                row["snippet"], chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
        ]
    else:
        metas = [
            {
                "title": doc_meta.name
                + (f" - Page {doc.end + 1}" if len(pages) > 1 else ""),
                "url": furl(f_url)
                .set(fragment_args={"page": doc.end + 1} if len(pages) > 1 else {})
                .url,
                "snippet": doc.text,
                "score": -1,
            }
            for doc in text_splitter(
                pages, chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
        ]
    # get doc embeds in batches
    embeds = []
    batch_size = 100
    texts = [m["title"] + " | " + m["snippet"] for m in metas]
    for i in range(0, len(texts), batch_size):
        # progress = int(i / len(texts) * 100)
        # print(f"Getting document embeddings ({progress}%)...")
        batch = texts[i : i + batch_size]
        embeds.extend(openai_embedding_create(batch))
    return list(zip(metas, embeds))
def doc_url_to_text_pages(
    *,
    f_url: str,
    doc_meta: DocMetadata,
    google_translate_target: str | None,
    selected_asr_model: str | None,
) -> list[str] | pd.DataFrame:
    """
    Download document from url and convert to text pages.

    Args:
        f_url: url of document
        doc_meta: document metadata
        google_translate_target: target language for google translate
        selected_asr_model: selected ASR model (used for audio files)

    Returns:
        list of text pages (or a DataFrame for spreadsheets)
    """
    f = furl(f_url)
    f_name = doc_meta.name
    if is_gdrive_url(f):
        # download from google drive
        f_bytes, ext = gdrive_download(f, doc_meta.mime_type)
    else:
        # download from url
        try:
            r = requests.get(
                f_url,
                headers={"User-Agent": random.choice(FAKE_USER_AGENTS)},
                timeout=10,
            )
            r.raise_for_status()
        except requests.RequestException as e:
            print(f"ignoring error while downloading {f_url}: {e}")
            return []
        f_bytes = r.content
        # if it's a known encoding, standardize to utf-8
        if r.encoding:
            try:
                codec = codecs.lookup(r.encoding)
            except LookupError:
                pass
            else:
                f_bytes = codec.decode(f_bytes)[0].encode()
        ext = guess_ext_from_response(r)
    # convert document to text pages
    match ext:
        case ".pdf":
            pages = pdf_to_text_pages(io.BytesIO(f_bytes))
        case ".docx" | ".md" | ".html" | ".rtf" | ".epub" | ".odt":
            pages = [pandoc_to_text(f_name + ext, f_bytes)]
        case ".txt":
            pages = [f_bytes.decode()]
        # case ".wav" | ".ogg" | ".mp3" | ".aac":
        #     if not selected_asr_model:
        #         raise ValueError(
        #             "For transcribing audio/video, please choose an ASR model from the settings!"
        #         )
        #     if is_gdrive_url(f):
        #         f_url = upload_file_from_bytes(
        #             f_name, f_bytes, content_type=doc_meta.mime_type
        #         )
        #     pages = [run_asr(f_url, selected_model=selected_asr_model, language="en")]
        case ".csv" | ".xlsx" | ".tsv" | ".ods":
            df = pd.read_csv(io.BytesIO(f_bytes), dtype=str).dropna()
            assert (
                "snippet" in df.columns
            ), f'uploaded spreadsheet must contain a "snippet" column - {f_name !r}'
            pages = df
        case _:
            raise ValueError(f"Unsupported document format {ext!r} ({f_name})")
    # optionally, translate text
    if google_translate_target:
        pages = run_google_translate(pages, google_translate_target)
    return pages
def pdf_to_text_pages(f: typing.BinaryIO) -> list[str]:
    return list(pdftotext.PDF(f))
def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str:
    """
    Convert document to text using pandoc.

    Args:
        f_name: filename of document
        f_bytes: document bytes
        to: pandoc output format (default: plain)

    Returns:
        extracted text content of document
    """
    with (
        tempfile.NamedTemporaryFile("wb", suffix="." + safe_filename(f_name)) as infile,
        tempfile.NamedTemporaryFile("r") as outfile,
    ):
        infile.write(f_bytes)
        infile.flush()  # ensure pandoc sees the full file, not just the buffered part
        args = [
            "pandoc",
            "--standalone",
            infile.name,
            "--to",
            to,
            "--output",
            outfile.name,
        ]
        print("\t$ " + " ".join(args))
        subprocess.check_call(args)
        return outfile.read()
def text_splitter(
    docs: str | typing.Sequence[str | Document],
    *,
    chunk_size: int,
    chunk_overlap: int = 0,
    separators: typing.Sequence[re.Pattern] = default_separators,
    length_function: L = default_length_function,
) -> list[Document]:
    if not docs:
        return []
    if isinstance(docs, str):
        docs = [docs]
    if isinstance(docs[0], str):
        docs = [Document(d, (idx, idx), length_function) for idx, d in enumerate(docs)]
    splits = _split(docs, chunk_size, separators)
    docs = list(_join(splits, chunk_size, chunk_overlap))
    return docs
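# Illustrative example (not part of the original module): with a simple word-count
# length function, a 6-word string split with chunk_size=4 and chunk_overlap=2 yields
# two chunks, the second repeating the last sentence of the first as overlap.
#
#   chunks = text_splitter(
#       "First sentence. Second sentence. Third sentence.",
#       chunk_size=4,
#       chunk_overlap=2,
#       length_function=lambda t: len(t.split()),
#   )
#   [c.text for c in chunks]
#   # -> ['First sentence. Second sentence.', 'Second sentence. Third sentence.']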
def _split(
    docs: list[Document],
    chunk_size: int,
    separators: typing.Sequence[re.Pattern],
) -> typing.Iterable[Document]:
    if not separators:
        raise ValueError("No separators left, cannot split further")
    for doc in docs:
        # skip empty docs
        if not doc.text.strip():
            continue
        # if the doc is small enough, no need to split
        if len(doc) <= chunk_size:
            yield doc
            continue
        for text in re_split(separators[0], doc.text):
            # skip empty fragments
            if not text.strip():
                continue
            frag = Document(text, doc.span, doc.length_function)
            # if the fragment is small enough, no need for further splitting
            if len(frag) <= chunk_size:
                yield frag
            else:
                # otherwise, retry with the next (finer-grained) separator
                yield from _split([frag], chunk_size, separators[1:])
def re_split(pat: re.Pattern, text: str):
    """Similar to re.split, but preserves the matched groups after splitting"""
    last_match_end = 0
    for match in pat.finditer(text):
        end_char = "".join(match.groups())
        frag = text[last_match_end : match.start()] + end_char
        if frag:
            yield frag
        last_match_end = match.end()
    yield text[last_match_end:]
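# Illustrative example (not part of the original module): unlike re.split, the captured
# separator stays attached to the fragment before it.
#
#   list(re_split(re.compile(r"\s*([.!?])\s*"), "One. Two! Three"))
#   # -> ['One.', 'Two!', 'Three']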
def _join(
    docs: typing.Iterable[Document],
    chunk_size: int,
    chunk_overlap: int,
) -> typing.Iterator[Document]:
    window = deque()
    window_len = 0
    for doc in docs:
        # grow window until largest possible chunk
        if window_len + len(doc) <= chunk_size:
            window.append(doc)
            window_len += len(doc)
        else:
            # return the window until now
            if window:
                yield _merge(window)
            # reset window
            prev_window = window
            window = deque([doc])
            window_len = len(doc)
            # add overlap from previous window
            overlap_len = 0
            for chunk in reversed(prev_window):
                if (
                    # check if overlap is too large
                    overlap_len + len(chunk) > chunk_overlap
                    # check if window is too large
                    or window_len + len(chunk) > chunk_size
                ):
                    break
                window.appendleft(chunk)
                overlap_len += len(chunk)
                window_len += len(chunk)
    # return the leftover
    if window:
        yield _merge(window)
def _merge(docs: typing.Iterable[Document]) -> Document:
    ret = functools.reduce(operator.add, docs)
    return Document(
        text=ret.text.strip(),  # remove whitespace after merge
        span=ret.span,
        length_function=ret.length_function,
    )
def guess_ext_from_response(response: requests.Response) -> str:
    content_type = response.headers.get("Content-Type", "application/octet-stream")
    mimetype = content_type.split(";")[0]
    return mimetypes.guess_extension(mimetype) or ""
def safe_filename(filename: str) -> str:
    matches = FILENAME_WHITELIST.finditer(filename)
    filename = "".join(match.group(0) for match in matches)
    p = Path(filename)
    out = truncate_filename(p.stem) + p.suffix
    return out


def truncate_filename(text: str, maxlen: int = 100, sep: str = "...") -> str:
    if len(text) <= maxlen:
        return text
    assert len(sep) <= maxlen
    mid = (maxlen - len(sep)) // 2
    return text[:mid] + sep + text[-mid:]
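# Illustrative example (not part of the original module): characters outside the
# whitelist are dropped, and over-long stems are truncated from the middle.
#
#   safe_filename("report: Q2/2023 (final).pdf")   # -> 'report Q22023 final.pdf'
#   truncate_filename("x" * 150)                    # -> 48 x's + '...' + 48 x's (99 chars)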
def is_gdrive_url(f: furl) -> bool:
    return f.host in ["drive.google.com", "docs.google.com"]
def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]:
    # get drive file id
    file_id = url_to_gdrive_file_id(f)
    # get metadata
    service = discovery.build("drive", "v3")
    # get files in drive directly
    if f.host == "drive.google.com":
        request = service.files().get_media(fileId=file_id)
        ext = mimetypes.guess_extension(mime_type)
    # export google docs to appropriate type
    else:
        mime_type, ext = docs_export_mimetype(f)
        request = service.files().export_media(fileId=file_id, mimeType=mime_type)
    # download
    file = io.BytesIO()
    downloader = MediaIoBaseDownload(file, request)
    done = False
    while done is False:
        _, done = downloader.next_chunk()
        # print(f"Download {int(status.progress() * 100)}%")
    f_bytes = file.getvalue()
    return f_bytes, ext
def url_to_gdrive_file_id(f: furl) -> str:
    # extract the google drive file ID
    try:
        # https://drive.google.com/u/0/uc?id=FILE_ID&...
        file_id = f.query.params["id"]
    except KeyError:
        # https://drive.google.com/file/d/FILE_ID/...
        # https://docs.google.com/document/d/FILE_ID/...
        try:
            file_id = f.path.segments[f.path.segments.index("d") + 1]
        except (IndexError, ValueError):
            raise ValueError(f"Bad google drive link: {str(f)!r}")
    return file_id
def docs_export_mimetype(f: furl) -> tuple[str, str]:
    """
    Return the mimetype to export google docs - https://developers.google.com/drive/api/guides/ref-export-formats

    Args:
        f (furl): google docs link

    Returns:
        tuple[str, str]: (mime_type, extension)
    """
    if "document" in f.path.segments:
        mime_type = "text/plain"
        ext = ".txt"
    elif "spreadsheets" in f.path.segments:
        mime_type = "text/csv"
        ext = ".csv"
    elif "presentation" in f.path.segments:
        mime_type = "application/pdf"
        ext = ".pdf"
    elif "drawings" in f.path.segments:
        mime_type = "application/pdf"
        ext = ".pdf"
    else:
        raise ValueError(f"Not sure how to export google docs url: {str(f)!r}")
    return mime_type, ext
def flatmap_parallel(
    fn: typing.Callable[[T], list[R]], it: typing.Sequence[T]
) -> list[R]:
    return flatten(map_parallel(fn, it))


def flatten(l1: typing.Iterable[typing.Iterable[T]]) -> list[T]:
    return [it for l2 in l1 for it in l2]


def map_parallel(
    fn: typing.Callable[[T], R],
    *iterables: typing.Sequence[T],
    max_workers: int | None = None,
) -> list[R]:
    assert iterables, "map_parallel() requires at least one iterable"
    # fall back to 1 worker so an empty iterable doesn't crash ThreadPoolExecutor
    max_workers = max_workers or max(map(len, iterables)) or 1
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fn, *iterables))
def gdrive_metadata(file_id: str) -> dict:
    service = discovery.build("drive", "v3")
    metadata = (
        service.files()
        .get(
            fileId=file_id,
            fields="name,md5Checksum,modifiedTime,mimeType",
        )
        .execute()
    )
    return metadata
def run_google_translate(texts: list[str], google_translate_target: str) -> list[str]:
    """
    Translate text using the Google Translate API.

    Args:
        texts (list[str]): Text to be translated.
        google_translate_target (str): Language code to translate to.

    Returns:
        list[str]: Translated text.
    """
    translate_client = translate_v2.Client()
    result = translate_client.translate(
        texts, target_language=google_translate_target, format_="text"
    )
    return [r["translatedText"] for r in result]
def openai_embedding_create(texts: list[str]) -> list[np.ndarray]:
    # replace newlines, which can negatively affect performance.
    texts = [whitespace_re.sub(" ", text) for text in texts]
    res = openai.Embedding.create(model="text-embedding-ada-002", input=texts)
    ret = np.array([data["embedding"] for data in res["data"]])
    # see - https://community.openai.com/t/text-embedding-ada-002-embeddings-sometime-return-nan/279664/5
    if np.isnan(ret).any():
        raise RuntimeError("NaNs detected in embedding")
        # raise openai.error.APIError("NaNs detected in embedding")  # this lets us retry
    expected = (len(texts), 1536)
    if ret.shape != expected:
        raise RuntimeError(
            f"Unexpected shape for embedding: {ret.shape} (expected {expected})"
        )
    return ret
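# Illustrative usage (not part of the original module; requires OPENAI_API_KEY).
# ada-002 embeddings are normalized to unit length, so a plain dot product between
# two vectors behaves like cosine similarity - this is what get_top_k_references
# relies on when it compares scores against the 0.7 cutoff.
#
#   q, d = openai_embedding_create(["refund policy", "returns and refunds"])
#   similarity = float(q.dot(d))  # closer to 1.0 means more similar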
def calc_gpt_tokens(
    text: str | list[str] | dict | list[dict],
    *,
    sep: str = "",
) -> int:
    try:
        enc = threadlocal.gpt2enc
    except AttributeError:
        enc = tiktoken.get_encoding("gpt2")
        threadlocal.gpt2enc = enc
    if isinstance(text, (str, dict)):
        messages = [text]
    else:
        messages = text
    combined = sep.join(
        content
        for entry in messages
        if (
            content := entry.get("content", "").strip()
            if isinstance(entry, dict)
            else str(entry)
        )
    )
    return len(enc.encode(combined))
def convo_window_clipper(
    window: list[dict],
    max_tokens,
    *,
    sep: str = "",
    step=2,
):
    # find the clip index (stepping over user/assistant message pairs) such that the
    # remaining suffix of the conversation still fits within max_tokens
    for i in range(len(window) - 2, -1, -step):
        if calc_gpt_tokens(window[i:], sep=sep) > max_tokens:
            return i + step
    return 0
def build_prompt(
    *,
    search_query: str,
    references: list[dict],
    task_instructions: str,
    system_prompt: str,
    history_window: list[dict],
) -> list[dict | str]:
    user_prompt = {"role": "user", "content": search_query}
    if references:
        user_prompt["content"] = (
            references_as_prompt(references)
            + f"\n**********\n{task_instructions.strip()}\n**********\n"
            + user_prompt["content"]
        )
    # truncate the history to fit the model's max tokens
    # (model_max_tokens and max_output_tokens are expected to be defined elsewhere,
    #  e.g. as module-level settings for the chosen model)
    safety_buffer = 100
    max_history_tokens = (
        model_max_tokens
        - calc_gpt_tokens([system_prompt, user_prompt])
        - max_output_tokens
        - safety_buffer
    )
    clip_idx = convo_window_clipper(history_window, max_history_tokens)
    history_window = history_window[clip_idx:]
    prompt_messages = [system_prompt, *history_window, user_prompt]
    return prompt_messages
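# Illustrative end-to-end sketch (not part of the original module): fetch references
# for a query, then build the prompt messages. model_max_tokens / max_output_tokens
# must be defined for build_prompt to work, and all quoted values are placeholders.
#
#   request = DocSearchRequest(
#       search_query="How do I reset my password?",
#       documents=["https://example.com/help-center.pdf"],
#       max_references=4,
#       max_context_words=200,
#       scroll_jump=5,
#       selected_asr_model=None,
#       google_translate_target=None,
#   )
#   refs = get_top_k_references(request)
#   messages = build_prompt(
#       search_query=request.search_query,
#       references=refs,
#       task_instructions="Answer using only the search results above.",
#       system_prompt="You are a helpful assistant.",
#       history_window=[],
#   )
#   # messages == [system_prompt, *history_window, user_prompt], ready to be adapted
#   # into a chat completion call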