Skip to content

Instantly share code, notes, and snippets.

@johtani
Created December 17, 2024 14:41
Show Gist options
  • Save johtani/a68b0d63b2bc6aec9c1ee32577206f3e to your computer and use it in GitHub Desktop.
Save johtani/a68b0d63b2bc6aec9c1ee32577206f3e to your computer and use it in GitHub Desktop.
MS COCOのデータをWeaviateで検索する画面をStreamlitで構成
from weaviate_connection import WeaviateConnection
import streamlit as st
def _run_search(
    conn,
    search_type: str,
    *,
    collection_name: str,
    query_text: str,
    filter_text: str,
    rerank_text: str,
    limit: int,
):
    """Run the search selected in the UI and return the connection's hits.

    Args:
        conn: The Weaviate connection object; it must provide ``query``,
            ``near_query`` and ``hybrid_query``.
        search_type: The radio-button label chosen by the user.
        collection_name: Collection to search.
        query_text: Text typed into the main search box.
        filter_text: Passed as ``filter=`` for the "vector+filter" mode only.
        rerank_text: Passed as ``rerank=`` for the "vector+rerank" mode only.
        limit: Maximum number of hits requested.

    Returns:
        Whatever the selected ``conn`` method returns.
    """
    properties_kagome = ["caption_ja"]
    properties_gse = ["caption_ja_gse"]
    # Arguments shared by every search mode.
    common = {
        "collection_name": collection_name,
        "query": query_text,
        "limit": limit,
    }

    if search_type == "**bm25-gse**":
        return conn.query(query_properties=properties_gse, **common)
    if search_type == "**bm25-kagome**":
        return conn.query(query_properties=properties_kagome, **common)
    if search_type == "**vector**":
        return conn.near_query(query_properties=properties_kagome, **common)
    if search_type == "**vector+filter by kagome**":
        return conn.near_query(
            query_properties=properties_kagome, filter=filter_text, **common
        )
    if search_type == "**vector+rerank by kagome**":
        return conn.near_query(
            query_properties=properties_kagome, rerank=rerank_text, **common
        )
    # Any remaining option (i.e. "**hybrid**") falls back to hybrid search.
    return conn.hybrid_query(query_properties=properties_kagome, **common)


def _show_hits(result, hits, search_type: str) -> None:
    """Render the hits (image, Flickr URL, score, captions) into ``result``.

    Args:
        result: Streamlit container that receives the output.
        hits: Sized iterable of hit objects exposing ``properties`` and
            ``metadata``.
        search_type: The selected radio-button label; decides which metadata
            field is shown as the score.
    """
    if len(hits) == 0:
        result.write("No results...")
        return

    # Vector-based modes display ``metadata.distance``; every other mode
    # displays ``metadata.score``.
    shows_distance = search_type.startswith("**vector")

    result.divider()
    for hit in hits:
        props = hit.properties
        # Single quotes inside the f-strings keep this valid before Python
        # 3.12, which is the first version to allow reusing the outer quote.
        result.image(f"../../images/{props['filename']}")
        result.write(f"Flickr : {props['flickr_url']}")
        if shows_distance:
            result.write(f"Score is {hit.metadata.distance:.3f}")
        else:
            result.write(f"Score is {hit.metadata.score:.3f}")
        result.table(props["caption_ja"])
        result.divider()


def main() -> None:
    """Build the Streamlit page: connection check, search form and results."""
    collection_name = "MultiModalKagome"
    limit = 20

    st.title("Multi Modal/Language CLIP Hybrid Search")

    # create connection with Weaviate
    conn = st.connection(
        "weaviate", type=WeaviateConnection, host="host.docker.internal"
    )
    if conn.client().is_ready():
        print("Connected Weaviate server!")
    else:
        st.caption("接続に問題がありそうです")
        # Every call below goes through this connection, so there is nothing
        # useful left to render once the server reports it is not ready.
        return

    if conn.exists(collection_name=collection_name) is False:
        st.caption("コレクションの準備ができいないようです")
        # Every call below targets this collection; stop instead of querying
        # a collection that was just reported missing.
        return

    st.write(
        f":red[{conn.total_count(collection_name=collection_name)}] images you can search"
    )

    search_type = st.radio(
        label="検索タイプ",
        options=[
            "**bm25-gse**",
            "**bm25-kagome**",
            "**vector**",
            "**hybrid**",
            "**vector+filter by kagome**",
            "**vector+rerank by kagome**",
        ],
        horizontal=True,
        label_visibility="collapsed",
    )

    # setup search box
    query_text = st.text_input("Put your words what what you want?", "", key="query")
    filter_text = st.text_input("Filter word?", "", key="filter")
    rerank_text = st.text_input("Rerank word?", "", key="rerank")

    # One container, created once, placed below the inputs to hold the results.
    result = st.container()
    if not query_text:
        return

    # perform search
    hits = _run_search(
        conn,
        search_type,
        collection_name=collection_name,
        query_text=query_text,
        filter_text=filter_text,
        rerank_text=rerank_text,
        limit=limit,
    )
    _show_hits(result, hits, search_type)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment