Skip to content

Instantly share code, notes, and snippets.

@masuidrive
Last active May 14, 2023 23:42
Show Gist options
  • Save masuidrive/ec0fbdb6f8f99632ecef299cdabd9516 to your computer and use it in GitHub Desktop.
Save masuidrive/ec0fbdb6f8f99632ecef299cdabd9516 to your computer and use it in GitHub Desktop.
Faiss-SentenceTransformers-demo1.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPzOvBZP8LLQng+t0sN87GZ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/masuidrive/ec0fbdb6f8f99632ecef299cdabd9516/faiss-sentencetransformers-demo1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nTRQiuJL3grO",
"outputId": "2ef13b3d-3328-4fb5-e2e2-29907b33cdec"
},
"outputs": [
],
"source": [
"! pip install faiss-cpu sentence-transformers"
]
},
{
"cell_type": "code",
"source": [
"import sqlite3\n",
"import numpy as np\n",
"from sentence_transformers import SentenceTransformer\n",
"import faiss\n",
"\n",
"# 1. Prepare the SQLite database\n",
"def create_faq_database():\n",
"    \"\"\"Create (or reset) the faq table in faq.db and seed it with sample questions.\n",
"\n",
"    Rows already in the table are deleted first, so after each call the table\n",
"    holds exactly the four sample questions below. Returns the open\n",
"    connection; it is not closed here.\n",
"    \"\"\"\n",
"    seed_questions = (\n",
"        \"How do I reset my password?\",\n",
"        \"What payment methods do you accept?\",\n",
"        \"Can I return an item I bought?\",\n",
"        \"How can I track my order?\",\n",
"    )\n",
"    db = sqlite3.connect(\"faq.db\")\n",
"    cur = db.cursor()\n",
"    cur.execute(\"CREATE TABLE IF NOT EXISTS faq (id INTEGER PRIMARY KEY, question TEXT)\")\n",
"    # Empty the table so repeated runs do not accumulate duplicate questions\n",
"    cur.execute(\"DELETE FROM faq\")\n",
"    cur.executemany(\"INSERT INTO faq (question) VALUES (?)\", ((text,) for text in seed_questions))\n",
"    db.commit()\n",
"    return db\n",
"\n",
"# 2. Compute the sentence embeddings\n",
"def compute_faq_embeddings(faq_questions, model):\n",
"    \"\"\"Encode every FAQ question with `model` and return the vectors as a NumPy array.\n",
"\n",
"    One row per question, in the order the questions were given.\n",
"    \"\"\"\n",
"    vectors = model.encode(faq_questions)\n",
"    return np.array(vectors)\n",
"\n",
"# 3. Build the Faiss index\n",
"def create_faiss_index(model, embeddings, doc_ids):\n",
"    \"\"\"Build an exact L2 Faiss index over `embeddings`, keyed by `doc_ids`.\n",
"\n",
"    Returns (index, index_flat_l2). `index` is an IndexIDMap, so its search\n",
"    results are the ids from `doc_ids`. `index_flat_l2` is the wrapped flat\n",
"    index; searching it directly yields insertion positions, not those ids.\n",
"    \"\"\"\n",
"    flat = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())\n",
"    id_mapped = faiss.IndexIDMap(flat)\n",
"    id_mapped.add_with_ids(embeddings, doc_ids)\n",
"    return id_mapped, flat\n",
"\n",
"# 4. Query processing\n",
"def search_faq(query, model, index, k=3):\n",
"    \"\"\"Return the labels of up to `k` stored vectors nearest to `query`, closest first.\n",
"\n",
"    With an IndexIDMap the labels are FAQ ids; with a bare flat index they are\n",
"    insertion positions. Faiss pads results with -1 when fewer than `k`\n",
"    vectors are available; those placeholders are dropped, so the returned\n",
"    NumPy array can be shorter than `k`.\n",
"    \"\"\"\n",
"    query_embedding = np.asarray(model.encode([query]), dtype=np.float32)\n",
"    _, labels = index.search(query_embedding, k)\n",
"    hits = labels[0]\n",
"    # -1 means \"no result\" in Faiss; never let it be mistaken for an id\n",
"    return hits[hits != -1]\n",
"\n",
"# 5. Fetch the results\n",
"def get_faq_results(faq_ids, conn):\n",
"    \"\"\"Fetch the faq rows whose ids are in `faq_ids`, in the order the ids are given.\n",
"\n",
"    Ids may be Python ints or NumPy integers (as Faiss returns them); they are\n",
"    converted to int because sqlite3 cannot bind NumPy integer types. Ids with\n",
"    no matching row are skipped. Returns a list of (id, question) tuples.\n",
"    \"\"\"\n",
"    wanted = [int(faq_id) for faq_id in faq_ids]\n",
"    if not wanted:\n",
"        return []\n",
"    placeholders = \",\".join(\"?\" * len(wanted))\n",
"    cursor = conn.cursor()\n",
"    cursor.execute(\"SELECT * FROM faq WHERE id IN ({})\".format(placeholders), wanted)\n",
"    # SQL gives no ordering guarantee for IN (...), so restore the requested (ranked) order\n",
"    rank = {faq_id: pos for pos, faq_id in enumerate(wanted)}\n",
"    return sorted(cursor.fetchall(), key=lambda row: rank[row[0]])\n",
"\n",
"# Create the database\n",
"conn = create_faq_database()\n",
"\n",
"# Load the sentence-transformers model\n",
"# model = SentenceTransformer(\"sentence-transformers/paraphrase-xlm-r-multilingual-v1\")  # alternative model\n",
"model = SentenceTransformer(\"sentence-transformers/paraphrase-distilroberta-base-v1\")\n",
"\n",
"# Read the FAQ rows and compute their embedding vectors\n",
"cursor = conn.cursor()\n",
"cursor.execute(\"SELECT id, question FROM faq\")\n",
"faq_data = cursor.fetchall()\n",
"faq_ids, faq_questions = zip(*faq_data)\n",
"faq_embeddings = compute_faq_embeddings(faq_questions, model)\n",
"\n",
"# Build the Faiss index\n",
"index, index_flat_l2 = create_faiss_index(model, faq_embeddings, faq_ids)\n",
"\n",
"# Run a query\n",
"query = \"How can I change my password?\"\n",
"# Search the IndexIDMap, not the bare flat index: it returns FAQ ids directly,\n",
"# whereas flat-index positions stop lining up with ids once FAQs are removed.\n",
"faq_indices = search_faq(query, model, index)\n",
"\n",
"# Fetch and print the matching FAQs\n",
"# int(): sqlite3 cannot bind NumPy integers; -1 is Faiss's \"no result\" marker\n",
"results = get_faq_results([int(i) for i in faq_indices if i != -1], conn)\n",
"print(\"Results for query:\", query)\n",
"for r in results:\n",
"    print(f\"ID: {r[0]}, Question: {r[1]}\")\n",
"\n",
"# Close the database connection (kept open here because the next cell still uses conn)\n",
"# conn.close()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A_Wug6ED4RrG",
"outputId": "73236b5b-02b3-45cd-cafb-caad7cd32e59"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Results for query: How can I change my password?\n",
"ID: 1, Question: How do I reset my password?\n",
"ID: 2, Question: What payment methods do you accept?\n",
"ID: 4, Question: How can I track my order?\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def add_faq(question, conn, model, index):\n",
"    \"\"\"Insert `question` into the faq table and add its embedding to `index`.\n",
"\n",
"    `index` must support add_with_ids (e.g. a faiss.IndexIDMap) so the vector\n",
"    is stored under the new row's id. Returns the new FAQ id.\n",
"    \"\"\"\n",
"    cursor = conn.cursor()\n",
"    cursor.execute(\"INSERT INTO faq (question) VALUES (?)\", (question,))\n",
"    faq_id = cursor.lastrowid\n",
"    conn.commit()\n",
"\n",
"    # Faiss expects float32 vectors and int64 ids; state the dtypes explicitly\n",
"    # rather than relying on the platform's default integer type.\n",
"    embedding = np.asarray(model.encode([question]), dtype=np.float32)\n",
"    index.add_with_ids(embedding, np.array([faq_id], dtype=np.int64))\n",
"\n",
"    return faq_id\n",
"\n",
"def remove_faq(faq_id, conn, index):\n",
"    \"\"\"Delete FAQ `faq_id` from the faq table, then drop its vector from `index`.\"\"\"\n",
"    conn.cursor().execute(\"DELETE FROM faq WHERE id=?\", (faq_id,))\n",
"    conn.commit()\n",
"\n",
"    # Remove the same id from the Faiss index so searches stop returning it\n",
"    ids_to_drop = np.array([faq_id], dtype=np.int64)\n",
"    index.remove_ids(faiss.IDSelectorArray(ids_to_drop))\n",
"\n",
"def get_answer(query, model, index):\n",
"    \"\"\"Return the faq rows that best match `query`.\n",
"\n",
"    `index` is expected to be an IndexIDMap, so search_faq already yields FAQ\n",
"    ids; they go straight to the database lookup (mapping them through\n",
"    `faq_ids` again would treat ids as positions and pick the wrong rows).\n",
"    Reads the notebook-level `conn` connection.\n",
"    \"\"\"\n",
"    hit_ids = search_faq(query, model, index)\n",
"    # Skip Faiss's -1 \"no result\" marker; int() because sqlite3 cannot bind NumPy ints\n",
"    return get_faq_results([int(i) for i in hit_ids if i != -1], conn)\n",
"\n",
"# Add an FAQ\n",
"new_question = \"What is your refund policy?\"\n",
"new_faq_id = add_faq(new_question, conn, model, index)\n",
"print(f\"Added new FAQ with ID: {new_faq_id}\")\n",
"\n",
"# Check: the new FAQ should now be found (note the query contains a typo)\n",
"query = \"What's the refund polcy?\"\n",
"faq_indices = search_faq(query, model, index)\n",
"print(f\"{query}: {faq_indices}\")\n",
"\n",
"# Remove the FAQ that was just added\n",
"faq_id_to_remove = new_faq_id\n",
"remove_faq(faq_id_to_remove, conn, index)\n",
"print(f\"Removed FAQ with ID: {faq_id_to_remove}\")\n",
"\n",
"# Check again with the same query: the removed id should be gone\n",
"faq_indices = search_faq(query, model, index)\n",
"print(f\"{query}: {faq_indices}\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "58fuA0234SOf",
"outputId": "bae8c6e3-00bb-4608-e8eb-dd018d0137c3"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Added new FAQ with ID: 5\n",
"What's the refund polcy?: [5 2 3]\n",
"Removed FAQ with ID: 5\n",
"What's the refund polcy?: [2 3 1]\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment