Skip to content

Instantly share code, notes, and snippets.

@failable
Last active June 19, 2024 03:15
Show Gist options
  • Save failable/0379edf7a5d82024a69a50194295372f to your computer and use it in GitHub Desktop.
Save failable/0379edf7a5d82024a69a50194295372f to your computer and use it in GitHub Desktop.
Qdrant viewer
import os

import taipy.gui.builder as tgb
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.models import PayloadFieldSchema, PayloadIndexInfo
from taipy.gui import Gui
OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"


def create_embeddings(
    client: OpenAI,
    doc: str,
    model: str = OPENAI_EMBEDDING_MODEL,
    **kwargs,
) -> list[float]:
    """Embed a single document with the OpenAI embeddings API.

    Args:
        client: OpenAI client used to issue the request.
        doc: Text to embed.
        model: Embedding model name; defaults to ``OPENAI_EMBEDDING_MODEL``.
        **kwargs: Forwarded unchanged to ``client.embeddings.create``.

    Returns:
        The embedding vector of the first item in the response (the only item
        when ``doc`` is a single string).
    """
    response = client.embeddings.create(input=doc, model=model, **kwargs)
    return response.data[0].embedding
def create_filter(field: str, schema: PayloadIndexInfo, value: str) -> dict:
    """Build a Qdrant field-condition dict that matches ``value`` on ``field``.

    Args:
        field: Payload key to filter on.
        schema: Index info for ``field``, i.e. a value of
            ``CollectionInfo.payload_schema`` as returned by
            ``get_payload_schema``. Only ``schema.data_type`` is read.
        value: Value to match.

    Returns:
        ``{"key": field, "match": {"text": value}}`` for text indexes, or
        ``{"key": field, "match": {"value": value}}`` for keyword indexes.

    Raises:
        ValueError: If the index type is neither ``text`` nor ``keyword``.
    """
    data_type = schema.data_type
    # Text indexes are queried with a full-text match ("text"); keyword
    # indexes with an exact-value match ("value").
    if data_type == "text":
        match_key = "text"
    elif data_type == "keyword":
        match_key = "value"
    else:
        msg = f"Unsupported schema: {data_type}"
        raise ValueError(msg)
    return {"key": field, "match": {match_key: value}}
def get_collection_names(qdrant_client: QdrantClient) -> list[str]:
    """Return the names of every collection on the server, in ascending order."""
    response = qdrant_client.get_collections()
    names = [collection.name for collection in response.collections]
    names.sort()
    return names
def get_payload_schema(qdrant_client: QdrantClient, collection_name: str) -> dict:
    """Return the filterable payload indexes of ``collection_name``.

    Only fields indexed as ``keyword`` or ``text`` are kept, since those are
    the index types ``create_filter`` knows how to query.

    Returns:
        A dict mapping field name to its index info, ordered by index type
        and then by field name.
    """
    filterable_types = ("keyword", "text")
    schema_by_field = qdrant_client.get_collection(collection_name).payload_schema
    selected = [
        (field, info)
        for field, info in schema_by_field.items()
        if info.data_type in filterable_types
    ]
    # Deterministic ordering keeps the generated filter inputs stable between
    # refreshes, which is friendlier for the user.
    selected.sort(key=lambda item: (item[1].data_type, item[0]))
    return dict(selected)
def on_qdrant_url_change(state, var, val):
    """Taipy ``on_change`` callback for the Qdrant URL input.

    Args:
        state: Taipy state of the session that edited the URL.
        var: Name of the changed variable (unused).
        val: The new Qdrant URL.
    """
    global qdrant_client
    # Keep the module global in sync for any code that still reads it.
    qdrant_client = QdrantClient(val)
    # refresh_filters reads the client through ``state``; rebinding only the
    # module global would leave this session talking to the old server.
    state.qdrant_client = qdrant_client
def on_collection_names_change(state, var, val):
    """Taipy ``on_change`` callback for the collection selector.

    ``var`` and ``val`` describe the change but are not needed: the selected
    collection is read back from ``state`` when the filters are rebuilt.
    """
    return refresh_filters(state)
def refresh_filters(state):
    """Rebuild the filter inputs for the currently selected collection.

    One input is rendered per filterable payload field reported by
    ``get_payload_schema``. If there is no client, no selected collection, or
    no filterable field, the partial is cleared. Otherwise inputs left over
    from a previously selected collection would stay on screen.
    """
    payload_schema = {}
    if state.qdrant_client and state.collection_name:
        payload_schema = get_payload_schema(
            state.qdrant_client,
            state.collection_name,
        )

    if not payload_schema:
        state.filter_partial.update_content(state, "")
        return

    with tgb.Page() as filter_part:
        for field, schema in payload_schema.items():
            # Pass the caption as ``label``. A bare positional argument sets
            # the input's value, which would show the caption as editable text.
            tgb.input(
                value="",
                label=f"{field} ({str.capitalize(schema.data_type)})",
            )
    state.filter_partial.update_content(state, filter_part)
def on_init(state):
    """Taipy session-start hook: build the initial filter inputs for ``state``."""
    refresh_filters(state=state)
if __name__ == "__main__":
    # NOTE: page expressions such as "{qdrant_url}" are resolved by Taipy
    # against this script's module-level variables, so the names assigned here
    # are part of the UI wiring.
    qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
    qdrant_client = QdrantClient(qdrant_url)
    collection_names = get_collection_names(qdrant_client)
    # Preselect the first (alphabetically sorted) collection, if any exist.
    collection_name = collection_names[0] if collection_names else None

    with tgb.Page() as page:
        with tgb.expandable("Options"):
            tgb.input(
                value="{qdrant_url}",
                label="Qdrant url",
                on_change=on_qdrant_url_change,
            )
        with tgb.layout(columns="1 1"):
            tgb.selector(
                value="{collection_name}",
                label="Collection",
                lov=collection_names,
                dropdown=True,
                on_change=on_collection_names_change,
            )
            tgb.input(value=10, label="Number of results")
        # Placeholder whose content refresh_filters() replaces with one input
        # per filterable payload field.
        tgb.part(partial="{filter_partial}")

    gui = Gui(page=page)
    # add_partial() takes the partial's initial page content. Start empty so
    # no placeholder text is rendered before the first refresh.
    filter_partial = gui.add_partial("")
    gui.run(
        title="Qdrant viewer",
        dark_mode=False,
        debug=True,
        use_reloader=True,
        port=5001,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment