@sirupsen · Last active March 5, 2023
I run all my searches in FZF as I edit my notes in Vim; see https://share.cleanshot.com/x9ZkfBQQ. It will require some tinkering, but it runs an inference/search server so the searches are fast. It's deployed to Modal.com.
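How the pieces below fit together, as a minimal sketch (port 8090 and the paths come from the scripts below; the queries are made-up examples):

# The first query forks a local TCP server on port 8090 that loads the
# models and the cached paragraph embeddings once, then answers this and
# every later query over the socket -- that warm server is what keeps
# fzf's reload() fast.
python ~/src/semsearch/search_webhook.py "spaced repetition"
# Subsequent queries hit the already-running server and return quickly.
python ~/src/semsearch/search_webhook.py "deliberate practice"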
#!/bin/bash
# FZF quotes field names, so I can't easily get the preview to work otherwise
cd "$ZK_PATH"
# --bind="2:preview(echo '\"{q}\",\"{4..}\",2' >> judge.csv && echo Rated 2!)"
# Exported so the fzf child process picks it up as its initial input command.
export FZF_DEFAULT_COMMAND="python ~/src/semsearch/search_webhook.py learning"
# Result lines are "score start:end scroll-offset path"; fzf shows fields
# 1 and 4.., and feeds {2} and {3} to bat and the preview window.
fzf \
  --bind="0:preview(ruby ~/src/semsearch/write.rb {q} {4..} 0 && echo Rated 0!)" \
  --bind="1:preview(ruby ~/src/semsearch/write.rb {q} {4..} 1 && echo Rated 1!)" \
  --bind="2:preview(ruby ~/src/semsearch/write.rb {q} {4..} 2 && echo Rated 2!)" \
  --header="Press [0], [1], [2] to add an item to the judgement list with that rating" \
  --prompt "Semantic > " \
  --bind "change:reload(python ~/src/semsearch/search_webhook.py {q})+change-prompt(Semantic > )" \
  --bind "tab:reload(textgrep --scores \"{q}\")+change-prompt(BM25 > )+unbind(change)" \
  --bind "btab:reload(python ~/src/semsearch/search_webhook.py \"{q}\")+change-prompt(Semantic > )+rebind(change)" \
  --disabled \
  --ansi \
  --with-nth '1,4..' \
  --no-hscroll \
  --preview-window 'top:65%,+{3}' \
  --no-multi \
  --height 100% \
  --tac \
  --query "learning" \
  --preview "bat --language md --color always --plain {4..} --highlight-line {2}"
import glob
import multiprocessing
import os
import pathlib
import signal
import socket
import sys
import time

import modal

volume = modal.SharedVolume().persist("model-cache")
stub = modal.Stub(
    "search-webhook",
    image=modal.Image.debian_slim().pip_install(["sentence-transformers"]),
)

note_path = "~/Documents/Zettelkasten/**/*.md"

# Lazily-initialized globals so the server only pays the load cost once.
paragraphs = None
bi_encoder = None
cross_encoder = None
# https://stackoverflow.com/questions/17412304/hashing-an-array-or-object-in-python-3
def checksum(data):
    import hashlib

    hash_id = hashlib.md5()
    hash_id.update(repr(data).encode("utf-8"))
    return hash_id.hexdigest()
# Split every note into paragraphs (the title counts as one), recording
# which file and which line range each paragraph came from.
def paragraphs_from_glob(paths):
    paragraphs = []
    paragraph_idx_to_path = {}
    paragraph_idx_to_lines = {}
    paths.sort()  # stable checksum!
    for path in paths:
        title = pathlib.Path(path)
        new_paragraphs = [title.stem.strip()]
        new_paragraphs.extend(title.read_text().split("\n\n"))
        new_paragraphs = list(filter(None, new_paragraphs))
        line = 1
        for paragraph_idx, paragraph in zip(
            range(len(paragraphs), len(paragraphs) + len(new_paragraphs)),
            new_paragraphs,
        ):
            assert paragraph_idx >= len(paragraphs)
            assert paragraph
            assert paragraph_idx not in paragraph_idx_to_path
            paragraph_idx_to_path[paragraph_idx] = path
            paragraph_idx_to_lines[paragraph_idx] = [
                line,
                line + paragraph.count("\n"),
            ]
            if paragraph_idx > len(paragraphs):  # not the title
                line += paragraph.count("\n") + 2  # for the \n\n
        paragraphs.extend(new_paragraphs)
    return {
        "paragraphs": paragraphs,
        "paragraph_idx_to_path": paragraph_idx_to_path,
        "paragraph_idx_to_lines": paragraph_idx_to_lines,
        "checksum": checksum(paragraphs),
    }
def cached_models():
    from sentence_transformers import CrossEncoder, SentenceTransformer

    cache_path = "/tmp/model-cache"
    if stub.is_inside():
        cache_path = "/root/models"
    os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_path  # SentenceTransformer
    os.environ["TORCH_HOME"] = cache_path  # CrossEncoder
    before = time.monotonic()
    bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    print("Instantiate bi and cross encoders: ", time.monotonic() - before, "seconds", file=sys.stderr)
    return bi_encoder, cross_encoder
@stub.function(
    gpu=True,
    shared_volumes={
        "/root/models": volume,
    },
)
def encode_paragraph_embeddings(paragraphs):
    bi_encoder, _cross_encoder = cached_models()
    print("Encoding paragraphs", file=sys.stderr)
    before = time.monotonic()
    paragraph_embeddings = bi_encoder.encode(  # type: ignore
        paragraphs["paragraphs"],
        # batch_size=128,
        show_progress_bar=True,
        convert_to_tensor=True,
    )
    print("Encode paragraphs: ", time.monotonic() - before, "seconds", file=sys.stderr)
    return paragraph_embeddings.cpu()  # move off the GPU for serialization
def cached_paragraphs(key) -> dict:
    import diskcache

    dc = diskcache.Cache("/tmp/zk-cache/")
    print("Cache key is", key, file=sys.stderr)
    if key in dc:
        before = time.monotonic()
        cached = dc.get(key)
        print("Get cached paragraph embeddings: ", time.monotonic() - before, "seconds", file=sys.stderr)
        return cached
    else:
        print("Running stub..", file=sys.stderr)
        with stub.run():
            before = time.monotonic()
            print("Getting paragraphs", file=sys.stderr)
            paragraphs = paragraphs_from_glob(glob.glob(os.path.expanduser(note_path), recursive=True))
            print("Load paragraphs: ", time.monotonic() - before, "seconds", file=sys.stderr)
            paragraphs["embeddings"] = encode_paragraph_embeddings(paragraphs)
        dc.set(key, paragraphs, expire=24 * 3600 * 5)  # cache for five days
        dc.expire()
        return paragraphs
# @stub.function(
#     gpu=False,
#     shared_volumes={
#         "/root/models": volume,
#     },
#     mounts=[
#         modal.Mount(
#             remote_dir="/root/notes",
#             local_dir="~/Documents/Zettelkasten",
#             # only allow markdown files
#             condition=lambda path: path.endswith(".md"),
#             recursive=True,
#         )
#     ],
# )
def search(query_string: str):
    global paragraphs
    global bi_encoder
    global cross_encoder
    from sentence_transformers import util

    before = time.monotonic()
    if paragraphs is None:
        paragraphs = cached_paragraphs("expiring-paragraphs")
        print("Get paragraph embeddings: ", time.monotonic() - before, "seconds", file=sys.stderr)
    if bi_encoder is None:
        bi_encoder, cross_encoder = cached_models()

    query_string = query_string.strip()
    query_embedding = bi_encoder.encode(query_string, convert_to_tensor=True)

    before = time.monotonic()
    results = util.semantic_search(query_embedding, paragraphs["embeddings"], top_k=32)[0]
    print("Search time: ", time.monotonic() - before, file=sys.stderr)

    # Re-rank the bi-encoder's candidates with the slower, more accurate
    # cross-encoder.
    deduped_results = {}
    before = time.monotonic()
    cross_inp = [
        [query_string, paragraphs["paragraphs"][result["corpus_id"]]]
        for result in reversed(results)
    ]
    cross_scores = cross_encoder.predict(cross_inp)
    print("Cross encoding time: ", time.monotonic() - before, "seconds\n", file=sys.stderr)

    for idx, match in enumerate(reversed(results)):
        paragraph_id = match["corpus_id"]
        path = paragraphs["paragraph_idx_to_path"][paragraph_id]
        start_line, end_line = paragraphs["paragraph_idx_to_lines"][paragraph_id]
        dirname = os.path.basename(os.path.dirname(path))
        title = os.path.basename(path)
        if dirname == "highlights":
            title = f"highlights/{title}"
        # score = match["score"]  # the bi-encoder score, superseded below
        score = cross_scores[idx]
        # Fields consumed by the fzf script: {2} is start:end for bat's
        # --highlight-line, {3} is the preview scroll offset, {4..} is the
        # path relative to $ZK_PATH.
        title_with_lines = f"{start_line}:{end_line} {max(start_line - 10, 0)} {title}"
        # Keep only the best-scoring paragraph per note. Other matches in the
        # same note are dropped, though they should probably contribute to the
        # score, similar to TF.
        if title not in deduped_results or score > deduped_results[title][0]:
            deduped_results[title] = [score, title_with_lines]

    # Sorted ascending by score; fzf's --tac puts the best match at the top.
    output = ""
    for _, match in sorted(deduped_results.items(), key=lambda item: item[1]):
        output += f"{match[0]:.2f} {match[1]}\n"
    return output[0:-1]
# Check whether the search server is already listening on the given port.
def port_open(port):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    result = sock.connect_ex(("127.0.0.1", port))
    sock.close()
    return result == 0
def simple_tcp_server():
    # Silence the server so stray prints never leak into fzf's input.
    with open(os.devnull, "w") as f1, open(os.devnull, "w") as f2:
        sys.stderr = f1
        sys.stdout = f2
        port = 8090
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("127.0.0.1", port))
        s.listen(5)

        def signal_handler(signum, frame):
            s.close()
            sys.exit(0)

        signal.signal(signal.SIGTERM, signal_handler)
        print("Listening on port", port, file=sys.stderr)
        while True:
            c, addr = s.accept()
            print("PID: ", os.getpid(), file=sys.stderr)
            query = c.recv(1024)
            # port_open() probes with an empty connection, so skip empty reads.
            if len(query) > 0:
                result = search(query.decode("utf-8"))
                c.sendall(result.encode("utf-8"))
            c.close()
def query_via_tcp(query):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect(("localhost", 8090))
    s.sendall(query.encode("utf-8"))
    print(s.recv(10_000).decode("utf-8"))
    s.close()
if __name__ == "__main__":
    query = " ".join(sys.argv[1:])
    if not query:
        print("No query was provided")
        sys.exit(0)
    if port_open(8090):
        query_via_tcp(query)
    else:
        # First query: fork a background server that keeps the models and
        # embeddings warm, wait for it to come up, then query it.
        p = multiprocessing.Process(target=simple_tcp_server, daemon=False, name="python search.py")
        p.start()
        for _ in range(500):
            time.sleep(0.1)
            if port_open(8090):
                break
        query_via_tcp(query)
        # Exit without joining the non-daemonic child so the server outlives us.
        os.kill(os.getpid(), signal.SIGTERM)