Skip to content

Instantly share code, notes, and snippets.

@Vesnica
Last active June 19, 2023 09:34
Show Gist options
  • Save Vesnica/c96115ba744edbe6735a9d476abc8002 to your computer and use it in GitHub Desktop.
Save Vesnica/c96115ba744edbe6735a9d476abc8002 to your computer and use it in GitHub Desktop.
Qdrant Json Embedding
import pathlib
import json
import itertools
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient, models
# Sentence-embedding model used both to index the poems and to embed queries.
m = SentenceTransformer(model_name_or_path="shibing624/text2vec-base-chinese-nli")

# Two Qdrant clients running in local mode: one held purely in memory and one
# persisted on disk under ./qdrant.
# NOTE(review): client_mem is not referenced by the rest of the visible script.
client_mem = QdrantClient(location=":memory:")
client_db = QdrantClient(path="qdrant")
# Load the Song-dynasty ci corpus and embed every poem.
#
# Each matching file is expected to hold a JSON list of poem objects that
# carry a 'paragraphs' list of strings (the code below flattens the per-file
# lists and joins each poem's paragraphs).
poetry_dir = pathlib.Path('chinese-poetry/宋词/')  # avoid shadowing builtin dir()

song_poetry = []
# Sort the matches: glob order is filesystem-dependent, and the position of a
# poem in song_poetry is later used as its point id, so a stable order keeps
# ids reproducible between runs and machines.
for p in sorted(poetry_dir.glob('ci.song.*.json')):
    # The corpus is UTF-8 Chinese text; do not rely on the locale's default
    # encoding (which is not UTF-8 on every platform).
    with p.open(encoding='utf-8') as f:
        song_poetry.append(json.load(f))

# Flatten the list of per-file lists into one flat list of poems.
song_poetry = list(itertools.chain.from_iterable(song_poetry))

# One string per poem (its paragraphs concatenated), then one vector per poem,
# in the same order as song_poetry.
paragraphs = [''.join(s['paragraphs']) for s in song_poetry]
paragraphs_vec = m.encode(paragraphs)
# Vector layout for the collection: dimensionality is taken from the embedding
# model so the two always agree; similarity is measured by cosine distance.
vector_params = models.VectorParams(
    size=m.get_sentence_embedding_dimension(),
    distance=models.Distance.COSINE,
)

# (Re)create the on-disk "song_poetry" collection with that layout.
# NOTE(review): recreate_collection presumably drops any existing collection
# of the same name first - confirm against the qdrant-client docs.
client_db.recreate_collection(
    collection_name="song_poetry",
    vectors_config=vector_params,
)
# Pair every poem with its embedding. The poem's position in song_poetry is
# its point id, and the whole poem dict is stored as the payload so it can be
# printed or filtered on at query time.
records = [
    models.Record(id=idx, vector=vec.tolist(), payload=poem)
    for idx, (vec, poem) in enumerate(zip(paragraphs_vec, song_poetry))
]

# Push all records into the "song_poetry" collection.
client_db.upload_records(collection_name="song_poetry", records=records)
# Unfiltered semantic search: embed the query "月亮" (moon) with the same model
# used for indexing and fetch the three nearest poems.
query_vector = m.encode("月亮").tolist()
hits = client_db.search(
    collection_name="song_poetry",
    query_vector=query_vector,
    limit=3,
)

# Show each matching poem's payload alongside its similarity score.
for hit in hits:
    print(hit.payload, "score:", hit.score)
# Same query as a filtered search: only poems whose payload field "author"
# matches one of the listed names (here just 陆游, Lu You) are considered.
author_condition = models.FieldCondition(
    key="author",
    match=models.MatchAny(any=["陆游"]),
)
author_filter = models.Filter(must=[author_condition])

query_vector = m.encode("月亮").tolist()
hits = client_db.search(
    collection_name="song_poetry",
    query_vector=query_vector,
    query_filter=author_filter,
    limit=3,
)

# Show each matching poem's payload alongside its similarity score.
for hit in hits:
    print(hit.payload, "score:", hit.score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment