Last active
May 14, 2023 23:42
-
-
Save masuidrive/ec0fbdb6f8f99632ecef299cdabd9516 to your computer and use it in GitHub Desktop.
Faiss-SentenceTransformers-demo1.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPzOvBZP8LLQng+t0sN87GZ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/masuidrive/ec0fbdb6f8f99632ecef299cdabd9516/faiss-sentencetransformers-demo1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "nTRQiuJL3grO", | |
"outputId": "2ef13b3d-3328-4fb5-e2e2-29907b33cdec" | |
}, | |
"outputs": [ | |
], | |
"source": [ | |
"! pip install faiss-cpu sentence-transformers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import sqlite3\n", | |
"import numpy as np\n", | |
"from sentence_transformers import SentenceTransformer\n", | |
"import faiss\n", | |
"\n", | |
"# 1. SQLiteデータベースの準備\n", | |
"def create_faq_database():\n", | |
" conn = sqlite3.connect(\"faq.db\")\n", | |
" cursor = conn.cursor()\n", | |
" cursor.execute(\"CREATE TABLE IF NOT EXISTS faq (id INTEGER PRIMARY KEY, question TEXT)\")\n", | |
" cursor.execute(\"DELETE FROM faq\")\n", | |
"\n", | |
" # サンプルデータを追加\n", | |
" sample_faqs = [\n", | |
" \"How do I reset my password?\",\n", | |
" \"What payment methods do you accept?\",\n", | |
" \"Can I return an item I bought?\",\n", | |
" \"How can I track my order?\"\n", | |
" ]\n", | |
"\n", | |
" cursor.executemany(\"INSERT INTO faq (question) VALUES (?)\", [(q,) for q in sample_faqs])\n", | |
" conn.commit()\n", | |
" return conn\n", | |
"\n", | |
"# 2. 文章の埋め込みの作成\n", | |
"def compute_faq_embeddings(faq_questions, model):\n", | |
" return np.array(model.encode(faq_questions))\n", | |
"\n", | |
"# 3. Faissインデックスの作成\n", | |
"def create_faiss_index(model, embeddings, doc_ids):\n", | |
" dimension = model.get_sentence_embedding_dimension()\n", | |
" index_flat_l2 = faiss.IndexFlatL2(dimension)\n", | |
" index = faiss.IndexIDMap(index_flat_l2)\n", | |
" index.add_with_ids(embeddings, doc_ids)\n", | |
" return index, index_flat_l2\n", | |
"\n", | |
"# 4. クエリ処理\n", | |
"def search_faq(query, model, index, k=3):\n", | |
" query_embedding = model.encode([query])\n", | |
" distances, indices = index.search(np.array(query_embedding), k)\n", | |
" return indices[0]\n", | |
"\n", | |
"# 5. 結果の取得\n", | |
"def get_faq_results(faq_ids, conn):\n", | |
" cursor = conn.cursor()\n", | |
" cursor.execute(\"SELECT * FROM faq WHERE id IN ({})\".format(\",\".join(\"?\" * len(faq_ids))), faq_ids)\n", | |
" return cursor.fetchall()\n", | |
"\n", | |
"# データベース作成\n", | |
"conn = create_faq_database()\n", | |
"\n", | |
"# sentence-transformersモデルを読み込む\n", | |
"# model = SentenceTransformer(\"sentence-transformers/paraphrase-xlm-r-multilingual-v1\")\n", | |
"model = SentenceTransformer(\"sentence-transformers/paraphrase-distilroberta-base-v1\")\n", | |
"\n", | |
"# FAQデータを読み込み、埋め込みベクトルを計算\n", | |
"cursor = conn.cursor()\n", | |
"cursor.execute(\"SELECT id, question FROM faq\")\n", | |
"faq_data = cursor.fetchall()\n", | |
"faq_ids, faq_questions = zip(*faq_data)\n", | |
"faq_embeddings = compute_faq_embeddings(faq_questions, model)\n", | |
"\n", | |
"# Faissインデックスを作成\n", | |
"index, index_flat_l2 = create_faiss_index(model, faq_embeddings, faq_ids)\n", | |
"\n", | |
"# クエリを入力して検索\n", | |
"query = \"How can I change my password?\"\n", | |
"faq_indices = search_faq(query, model, index_flat_l2)\n", | |
"\n", | |
"# 結果を取得して表示\n", | |
"results = get_faq_results([faq_ids[i] for i in faq_indices], conn)\n", | |
"print(\"Results for query:\", query)\n", | |
"for r in results:\n", | |
" print(f\"ID: {r[0]}, Question: {r[1]}\")\n", | |
"\n", | |
"# データベース接続を閉じる\n", | |
"# conn.close()\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "A_Wug6ED4RrG", | |
"outputId": "73236b5b-02b3-45cd-cafb-caad7cd32e59" | |
}, | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Results for query: How can I change my password?\n", | |
"ID: 1, Question: How do I reset my password?\n", | |
"ID: 2, Question: What payment methods do you accept?\n", | |
"ID: 4, Question: How can I track my order?\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def add_faq(question, conn, model, index):\n", | |
" cursor = conn.cursor()\n", | |
" cursor.execute(\"INSERT INTO faq (question) VALUES (?)\", (question,))\n", | |
" faq_id = cursor.lastrowid\n", | |
" conn.commit()\n", | |
"\n", | |
" embedding = model.encode([question])\n", | |
" index.add_with_ids(np.array(embedding), np.array([faq_id]))\n", | |
"\n", | |
" return faq_id\n", | |
"\n", | |
"def remove_faq(faq_id, conn, index):\n", | |
" cursor = conn.cursor()\n", | |
" cursor.execute(\"DELETE FROM faq WHERE id=?\", (faq_id,))\n", | |
" conn.commit()\n", | |
"\n", | |
" # FaissインデックスからIDを削除\n", | |
" id_selector = faiss.IDSelectorArray(np.array([faq_id], dtype=np.int64))\n", | |
" index.remove_ids(id_selector)\n", | |
"\n", | |
"def get_answer(query, model, index):\n", | |
" faq_indices = search_faq(query, model, index)\n", | |
" results = get_faq_results([faq_ids[i] for i in faq_indices], conn)\n", | |
" return results\n", | |
"\n", | |
"# FAQを追加\n", | |
"new_question = \"What is your refund policy?\"\n", | |
"new_faq_id = add_faq(new_question, conn, model, index)\n", | |
"print(f\"Added new FAQ with ID: {new_faq_id}\")\n", | |
"\n", | |
"# 確認\n", | |
"query = \"What's the refund polcy?\"\n", | |
"faq_indices = search_faq(query, model, index)\n", | |
"print(f\"{query}: {faq_indices}\")\n", | |
"\n", | |
"# FAQを削除\n", | |
"faq_id_to_remove = new_faq_id\n", | |
"remove_faq(faq_id_to_remove, conn, index)\n", | |
"print(f\"Removed FAQ with ID: {faq_id_to_remove}\")\n", | |
"\n", | |
"# 確認\n", | |
"query = \"What's the refund polcy?\"\n", | |
"faq_indices = search_faq(query, model, index)\n", | |
"print(f\"{query}: {faq_indices}\")\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "58fuA0234SOf", | |
"outputId": "bae8c6e3-00bb-4608-e8eb-dd018d0137c3" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Added new FAQ with ID: 5\n", | |
"What's the refund polcy?: [5 2 3]\n", | |
"Removed FAQ with ID: 5\n", | |
"What's the refund polcy?: [2 3 1]\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment