Skip to content

Instantly share code, notes, and snippets.

@ceshine
Created Mar 14, 2021
Embed
What would you like to do?
Streamlit script that caches the loading of a FAISS index (live at https://news-search.veritable.pw)
import os
import sqlite3
import datetime
from typing import List
import faiss
import numpy as np
import pandas as pd
import joblib
import requests
import streamlit as st
# Presumably silences the HuggingFace `tokenizers` fork/parallelism warning;
# nothing visible in this file imports that library -- TODO confirm it is needed.
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Endpoint of the embedding service queried by get_embeddings; override it
# with the API_URI environment variable.
api_uri = os.getenv("API_URI", "http://localhost:8666/")
@st.cache(allow_output_mutation=True)
def load_data():
    """Load the search resources once and share them across Streamlit reruns.

    Returns:
        Tuple ``(conn, full_ids, index, default_date_range)``:

        * ``conn`` -- SQLite connection to ``data/news.sqlite``.
        * ``full_ids`` -- object loaded from ``data/ids.jbl``; presumably a
          sequence mapping FAISS row positions to entry ids (TODO confirm).
        * ``index`` -- FAISS index read from ``data/index.faiss``.
        * ``default_date_range`` -- mutable ``[start, end]`` list of dates.

    Because of ``allow_output_mutation=True`` the *same* objects are returned
    on every rerun and to every session: in-place changes to
    ``default_date_range`` are visible to all users, and its ``end`` is
    ``date.today()`` as of the moment the cache was first filled.
    """
    # Streamlit executes script runs on worker threads that are not
    # necessarily the thread that filled this cache, and sqlite3 by default
    # refuses to use a connection from any other thread. This script only
    # issues SELECT queries through the connection, so the check is disabled.
    conn = sqlite3.connect("data/news.sqlite", check_same_thread=False)
    full_ids = joblib.load("data/ids.jbl")
    index = faiss.read_index("data/index.faiss")
    default_date_range = [datetime.date(2018, 11, 28), datetime.date.today()]
    return conn, full_ids, index, default_date_range
def fetch_entries(conn: sqlite3.Connection, ids: List, date_range, scores):
    """Fetch the entries whose id is in ``ids`` and whose date lies in ``date_range``.

    Args:
        conn: SQLite connection holding an ``entries`` table with ``id``,
            ``date``, ``title`` and ``desc`` columns.
        ids: Entry ids to look up. Any sequence (list, tuple, array) works.
        date_range: ``(start, end)`` pair of ``datetime.date``. Both bounds
            are inclusive and are compared against ``date`` as ISO strings.
        scores: Relevance scores aligned position-by-position with ``ids``.

    Returns:
        DataFrame with columns ``id``, ``date`` (parsed with
        ``pd.to_datetime``), ``title``, ``desc`` and ``score``, sorted by
        ``score`` descending. Ids missing from the table or outside the date
        range are simply absent. If ``ids`` repeats an id, its last score wins.
    """
    # Materialise once: the SQL parameters are built by list concatenation,
    # and ``scores`` is zipped against the same sequence below.
    ids = list(ids)
    placeholders = ",".join(["?"] * len(ids))
    sql = (
        "SELECT id, date, title, desc FROM entries "
        f"WHERE id IN ({placeholders}) "
        " AND date >= ? AND date <= ?;"
    )
    params = ids + [day.isoformat() for day in date_range]

    cur = conn.cursor()
    try:
        cur.execute(sql, params)
        rows = cur.fetchall()
    finally:
        cur.close()  # release the statement promptly instead of at GC time

    results = pd.DataFrame(rows, columns=["id", "date", "title", "desc"])
    results["date"] = pd.to_datetime(results["date"])
    score_by_id = dict(zip(ids, scores))
    results["score"] = results["id"].map(score_by_id)
    return results.sort_values("score", ascending=False)
def get_latest_date(conn: sqlite3.Connection):
    """Return the largest ``date`` value in the ``entries`` table.

    The value is returned exactly as SQLite stores it (no parsing). It is
    ``None`` when the table is empty, because ``MAX`` over zero rows is NULL.
    """
    # An aggregate without GROUP BY always yields exactly one single-column row.
    (latest,) = conn.execute("SELECT MAX(date) FROM entries;").fetchone()
    return latest
def get_embeddings(text: str, timeout: float = 30.0):
    """Embed ``text`` by POSTing it as JSON ``{"text": ...}`` to ``api_uri``.

    Args:
        text: Text to embed.
        timeout: Seconds to wait on the embedding service before giving up.
            Without a timeout, a stalled service would block the Streamlit
            script run indefinitely.

    Returns:
        float32 array with a leading axis of length 1 (one query row), read
        from the response's ``"vector"`` field. It is shaped ``(1, dim)``
        assuming the service returns a flat list -- TODO confirm.

    Raises:
        RuntimeError: if the service answers with a status other than 200.
        requests.RequestException: on connection failure or timeout.
    """
    response = requests.post(api_uri, json={"text": text}, timeout=timeout)
    # An explicit check rather than ``assert``: asserts are stripped under -O.
    if response.status_code != 200:
        raise RuntimeError(
            f"Embedding API returned HTTP {response.status_code}: {response.text}"
        )
    # FAISS requires float32 input; build it directly in that dtype.
    vector = np.asarray(response.json()["vector"], dtype="float32")
    return vector[np.newaxis, :]
def _render_results(conn, full_ids, index, query, date_range, sort_method):
    """Embed ``query``, search the FAISS index and write up to 20 hits to the page."""
    embs = get_embeddings(query)
    faiss.normalize_L2(embs)
    scores, index_matches = index.search(embs, k=100)
    # FAISS pads results with -1 when fewer than k neighbours exist; without
    # this filter ``full_ids[-1]`` would silently inject an unrelated entry.
    hits = [
        (full_ids[i], score)
        for i, score in zip(index_matches[0], scores[0])
        if i >= 0
    ]
    df_entries = fetch_entries(
        conn,
        [entry_id for entry_id, _ in hits],
        date_range,
        [score for _, score in hits],
    ).iloc[:20]
    if sort_method != "relevance":
        # Returned sort rather than inplace=True on an ``.iloc`` slice, which
        # can trigger pandas' SettingWithCopyWarning.
        df_entries = df_entries.sort_values(
            "date", ascending=sort_method == "date (asc)"
        )
    for entry in df_entries.itertuples(index=False):
        # Ids look like "<YYYYMMDD>_<section>_<num>" (inferred from the link
        # format below); everything after the first "_" forms the link suffix.
        date, piece = entry.id.split("_", 1)
        st.write(
            f"{date[:4]}/{date[4:6]}/{date[6:8]} "
            f"[{entry.title}](https://news.veritable.pw/zh/piece/{date}/{piece}) (score: {entry.score:.4f})"
        )


def main():
    """Render the search page: query box, date-range and sort controls, results.

    A search runs only when the query is longer than 10 characters and both
    ends of the date range are chosen.
    """
    first_day = datetime.date(2018, 11, 28)  # earliest date offered by the UI
    st.title('Veritable News Semantic Search Engine')
    query = st.text_area(
        "Context/主題 (length > 10)", "", max_chars=256
    ).replace("\n", " ")
    conn, full_ids, index, default_date_range = load_data()
    # NOTE(review): ``default_date_range`` is the list cached by load_data
    # with allow_output_mutation=True, so the assignments below persist across
    # reruns *and across all sessions*: one visitor pressing "Last 90 days"
    # changes the default for everyone. Per-session state would fix this if
    # the installed Streamlit version offers it -- TODO confirm.
    # The cached end date was frozen when the cache was filled; refresh it so
    # the default range keeps covering newly ingested articles.
    default_date_range[1] = datetime.date.today()
    if st.button("Last 90 days"):
        default_date_range[0] = datetime.date.today() - datetime.timedelta(days=90)
    if st.button("From the start"):
        default_date_range[0] = first_day
    date_range = st.date_input(
        'Date Range/日期範圍',
        value=default_date_range,
        min_value=first_day,
        max_value=datetime.date.today() + datetime.timedelta(days=1)
    )
    sort_method = st.selectbox(
        "Sort by:", ("relevance", "date (desc)", "date (asc)"))
    if len(query) > 10 and len(date_range) == 2:
        _render_results(conn, full_ids, index, query, date_range, sort_method)
    latest_date = get_latest_date(conn)
    st.write(
        f"Data updated on _{latest_date}_; Engine updated on _2021-03-07_")
    st.write("_This app uses a TinyBERT-4L model to reduce hardware requirements. If you're interested in bigger and more powerful models, please e-mail **ceshine at veritable.pw**_")
# Script entry point. Streamlit presumably executes this file as ``__main__``
# (e.g. via ``streamlit run``), so the page is rendered on every script run.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment