"""Streamlit script that caches the loading of a FAISS index.

Semantic news search demo: embeds the user's query via an external API,
searches a FAISS index, and renders matching entries from a SQLite database.
Originally shared as a GitHub gist (created Mar 14, 2021).
"""
import os
import sqlite3
import datetime
from typing import List
import faiss
import numpy as np
import pandas as pd
import joblib
import requests
import streamlit as st
# Disable tokenizer parallelism to silence the `tokenizers` library's
# fork/threading warning when running under Streamlit's worker threads.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Base URI of the external embedding service; override with the API_URI
# environment variable (defaults to a local dev server).
api_uri = os.environ.get("API_URI", "http://localhost:8666/")
@st.cache(allow_output_mutation=True)
def load_data():
    """Load the database connection, id mapping, and FAISS index (cached).

    Cached by Streamlit so the expensive index/id loads happen once per
    process instead of on every script rerun. `allow_output_mutation=True`
    because `main()` mutates the returned default date range in place and
    the connection/index objects are not hashable.

    Returns:
        Tuple of (conn, full_ids, index, default_date_range) where
        default_date_range is ``[corpus_start_date, today]``.
    """
    # check_same_thread=False: the cached connection is reused from
    # Streamlit's script threads, not the thread that created it.
    conn = sqlite3.connect("data/news.sqlite", check_same_thread=False)
    # full_ids maps FAISS row positions back to entry id strings.
    full_ids = joblib.load("data/ids.jbl")
    index = faiss.read_index("data/index.faiss")
    # NOTE(review): this line was corrupted in the paste; 2020-11-28 is the
    # corpus start date used elsewhere in main() — confirm the original value.
    default_date_range = [datetime.date(2020, 11, 28), datetime.date.today()]
    return conn, full_ids, index, default_date_range
def fetch_entries(conn: sqlite3.Connection, ids: List, date_range, scores):
    """Fetch matched entries from the database, ordered by relevance.

    Args:
        conn: open connection to the news database.
        ids: entry id strings returned by the FAISS search, in match order.
        date_range: two-item sequence of ``datetime.date`` (start, end),
            both inclusive.
        scores: similarity scores aligned one-to-one with ``ids``.

    Returns:
        DataFrame with columns id/date/title/desc/score, sorted by score
        in descending order.
    """
    cur = conn.cursor()
    # Parameterized IN (...) clause: one "?" placeholder per id, so the
    # query stays safe and correct regardless of the id values.
    cur.execute(
        "SELECT id, date, title, desc FROM entries " +
        "WHERE id IN ({seq}) ".format(
            seq=",".join(["?"] * len(ids))
        ) +
        " AND date >= ? AND date <= ?;",
        ids + [x.isoformat() for x in date_range]
    )
    results = pd.DataFrame(
        cur.fetchall(), columns=["id", "date", "title", "desc"]
    )
    results["date"] = pd.to_datetime(results["date"])
    # Attach each entry's similarity score back onto its result row.
    score_dict = {key: score for key, score in zip(ids, scores)}
    results["score"] = results["id"].apply(lambda x: score_dict[x])
    results.sort_values("score", ascending=False, inplace=True)
    return results
def get_latest_date(conn: sqlite3.Connection):
    """Return the most recent `date` value stored in the entries table."""
    cursor = conn.cursor()
    cursor.execute("SELECT MAX(date) FROM entries;")
    (latest,) = cursor.fetchone()
    return latest
def get_embeddings(text: str):
    """Request the embedding vector for `text` from the embedding service.

    Returns:
        float32 numpy array of shape (1, dim), ready to pass to
        ``faiss.Index.search``.
    """
    # NOTE(review): the call head was corrupted in the paste; reconstructed
    # as a POST to the service behind `api_uri` — confirm the endpoint path.
    response = requests.post(api_uri, json={"text": text})
    # Fail loudly on a non-200 so the Streamlit app surfaces the API error.
    assert response.status_code == 200, response.text
    return np.asarray(response.json()["vector"])[np.newaxis, :].astype("float32")
def main():
    """Render the Streamlit UI: query box, date filters, and result list."""
    st.title('Veritable News Semantic Search Engine')
    query = st.text_area(
        "Context/主題 (length > 10)", "", max_chars=256
    ).replace("\n", " ")
    conn, full_ids, index, default_date_range = load_data()
    # Quick presets for the start of the date range.
    if st.button("Last 90 days"):
        default_date_range[0] = datetime.date.today() - \
            datetime.timedelta(days=90)
    if st.button("From the start"):
        default_date_range[0] = datetime.date(2020, 11, 28)
    # NOTE(review): the min/max bounds were corrupted in the paste;
    # reconstructed from the corpus start date used above — confirm.
    date_range = st.date_input(
        'Date Range/日期範圍',
        value=default_date_range,
        min_value=datetime.date(2020, 11, 28),
        max_value=datetime.date.today() + datetime.timedelta(days=1),
    )
    sort_method = st.selectbox(
        "Sort by:", ("relevance", "date (desc)", "date (asc)"))
    # Only search once the query is long enough and both dates are picked
    # (st.date_input returns a single date while the user is mid-selection).
    if len(query) > 10 and len(date_range) == 2:
        embs = get_embeddings(query)
        scores, index_matches = index.search(embs, k=100)
        df_entries = fetch_entries(
            conn,
            [full_ids[i] for i in index_matches[0]],
            date_range,
            scores[0],
        )
        if sort_method != "relevance":
            df_entries.sort_values(
                "date", ascending=sort_method == "date (asc)", inplace=True
            )
        for row in df_entries.values:
            # Entry ids encode their location as "<yyyymmdd>_<section>_<num>".
            date, section, num = row[0].split("_")
            # NOTE(review): the link's base URL was lost in the paste —
            # restore the original article URL prefix before the path.
            st.markdown(
                f"{date[:4]}/{date[4:6]}/{date[6:8]} "
                f"[{row[2]}]({date}/{section}_{num}) (score: {row[-1]:.4f})"
            )
    latest_date = get_latest_date(conn)
    st.markdown(
        f"Data updated on _{latest_date}_; Engine updated on _2021-03-07_")
    st.write("_This app uses a TinyBERT-4L model to reduce hardware requirements. If you're interested in bigger and more powerful models, please e-mail **ceshine at**_")
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()