ShinoLang.ipynb
@Sh1n0g1, created May 20, 2023
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOfCE3BG3+3zC/BLFpR/z4e",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Sh1n0g1/eee5daef2ab9dacc14348c095620630e/shinolang.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# ShinoLang\n",
"A system for ingesting documents and asking questions about them (LangChain + FAISS + ChatGPT).\n",
"The following document types can be loaded:\n",
"* PDF\n",
"* Word documents (.doc/.docx) (mostly untested)\n",
"* HTML files (.htm/.html)\n",
"* Other text files (.txt/\\*.\\*)\n",
"## Usage (first run)\n",
"1. From the `Edit` menu above, click `Clear all outputs`\n",
"1. Fill in the \"1. Settings\" form directly below\n",
"1. Fill in \"5. Question\"\n",
"1. From the `Runtime` menu above, click `Run all`\n",
"1. Execution pauses at \"4. Ingestion\"; upload your documents to the indicated directory\n",
"1. The answer to the question is finally displayed at `5.1`\n",
"## Limitations\n",
"* HTML files containing multibyte characters can fail to load. Converting them to PDF with the browser's print feature and loading the PDF works well.\n",
"\n"
],
"metadata": {
"id": "uVx7-oyKBBut"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Settings\n",
"import os\n",
"import sys\n",
"from IPython.display import clear_output\n",
"\n",
"#@markdown ### Google Drive integration\n",
"#@markdown Required if you want to reuse the ingested data in later sessions.\n",
"#@markdown A directory named \"ShinoLang\" will be created.\n",
"\n",
"enable_google_drive = True #@param {type:\"boolean\"}\n",
"\n",
"#@markdown ### ChatGPT API key <font color='red'>*</font>\n",
"chatgpt_api_key = \"\" #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Folder name for your documents <font color='red'>*</font>\n",
"data_dir = \"documents\" #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Parameters\n",
"#@markdown If unsure, leave the defaults as they are.\n",
"\n",
"chunk_size = 500 #@param {type:\"integer\"}\n",
"chunk_overlap_size = 50 #@param {type:\"integer\"}\n",
"num_of_chunk_send_to_chatgpt = 10 #@param {type:\"integer\"}\n",
"chain_type = \"stuff\" #@param {type:\"string\"}\n",
"\n",
"\n",
"if chatgpt_api_key == \"\" or data_dir == \"\":\n",
"  print(\"\\033[91mPlease fill in the required (*) settings.\")\n",
"  sys.exit()\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = chatgpt_api_key\n",
"\n",
"clear_output()\n",
"print('\\033[92m\\u2714 Done')\n",
"\n"
],
"metadata": {
"cellView": "form",
"id": "ZS73dOyGAleW"
},
"execution_count": null,
"outputs": []
},
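{
"cell_type": "markdown",
"source": [
"A quick illustration of what `chunk_size` and `chunk_overlap_size` control (a minimal sketch using the same `RecursiveCharacterTextSplitter` as the ingestion cell below; the small numbers here are only for demonstration):\n",
"\n",
"```python\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)\n",
"chunks = splitter.split_text(\"a\" * 50)\n",
"# Each chunk is at most 20 characters long, and consecutive chunks\n",
"# share up to 5 characters, so context is not lost at chunk borders.\n",
"```\n"
],
"metadata": {}
},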
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ugiW3SYv-q00"
},
"outputs": [],
"source": [
"#@title 2. Install & import required modules\n",
"\n",
"!pip install pypdf\n",
"!pip install langchain\n",
"!pip install openai\n",
"!pip install faiss-cpu\n",
"!pip install tiktoken\n",
"\n",
"from IPython.display import display, Markdown\n",
"from google.colab import drive\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.document_loaders import TextLoader\n",
"from langchain.document_loaders import PyPDFLoader\n",
"from langchain.document_loaders import Docx2txtLoader\n",
"from langchain.document_loaders import BSHTMLLoader\n",
"from langchain.chains import RetrievalQA\n",
"from langchain.llms import OpenAI\n",
"from langchain.callbacks import get_openai_callback\n",
"\n",
"clear_output()\n",
"print('\\033[92m\\u2714 Done')\n"
]
},
{
"cell_type": "code",
"source": [
"#@title 3. Path setup\n",
"GOOGLE_DRIVE_PATH = '/content/gdrive/'\n",
"ROOT_DIR = 'ShinoLang/'\n",
"LOADED_DIR_SUFFIX = \"_loaded/\"\n",
"FAILED_DIR_SUFFIX = \"_failed/\"\n",
"DB_DIR_SUFFIX = \"_db/\"\n",
"\n",
"if enable_google_drive:\n",
"  drive.mount(GOOGLE_DRIVE_PATH)\n",
"  install_path = GOOGLE_DRIVE_PATH + 'MyDrive/' + ROOT_DIR\n",
"else:\n",
"  install_path = '/content/' + ROOT_DIR\n",
"if not os.path.exists(install_path):\n",
"  os.mkdir(install_path)\n",
"\n",
"os.chdir(install_path)\n",
"if data_dir.endswith(\"/\"):\n",
"  data_dir = data_dir[:-1]\n",
"\n",
"# Data Path\n",
"data_path = install_path + data_dir\n",
"if not os.path.exists(data_path):\n",
"  os.mkdir(data_path)\n",
"\n",
"# Loaded Path\n",
"loaded_path = install_path + data_dir + LOADED_DIR_SUFFIX\n",
"if not os.path.exists(loaded_path):\n",
"  os.mkdir(loaded_path)\n",
"\n",
"# Failed Path\n",
"failed_path = install_path + data_dir + FAILED_DIR_SUFFIX\n",
"if not os.path.exists(failed_path):\n",
"  os.mkdir(failed_path)\n",
"\n",
"# DB Path\n",
"db_path = install_path + data_dir + DB_DIR_SUFFIX\n",
"if not os.path.exists(db_path):\n",
"  os.mkdir(db_path)\n",
"\n",
"print(f\"Install Path: {install_path}\")\n",
"print(f\"   Data Path: {data_path}\")\n",
"print(f\" Loaded Path: {loaded_path}\")\n",
"print(f\" Failed Path: {failed_path}\")\n",
"print(f\"     DB Path: {db_path}\")\n",
"print('\\033[92m\\u2714 Done')\n"
],
"metadata": {
"cellView": "form",
"id": "_zMo447r_H5C"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 4.1 Ingestion (file upload)\n",
"#@markdown File upload\n",
"#@markdown * Upload the files you want to ingest to the `Data Path` shown above.\n",
"#@markdown * If the directory does not appear, click `Refresh` in the file tree.\n",
"#@markdown * If Google Drive is linked, uploading directly through Drive is faster.\n",
"#@markdown * Press [Enter] once the upload is finished, or to skip ingestion.\n",
"\n",
"input(\"Press Enter to continue\")\n",
"\n",
"\n"
],
"metadata": {
"cellView": "form",
"id": "Ms0eaoPZU1aa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 4.2 Ingestion (run)\n",
"#@markdown Ingests the documents. Large documents may take several minutes.\n",
"def load_files(dir):\n",
"  # Loading files\n",
"  files = [f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))]\n",
"  print(f\"Files Loading... {len(files)} files.\")\n",
"  for file in files:\n",
"    print(f\"[+] Reading {file}\")\n",
"    #TODO Check if the file already exists on \"loaded_dir\"\n",
"\n",
"    file_path = dir + file\n",
"    if file.endswith(\".pdf\"):\n",
"      try:\n",
"        loader = PyPDFLoader(file_path)\n",
"      except Exception as e:\n",
"        exc_type, exc_obj, exc_tb = sys.exc_info()\n",
"        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]\n",
"        print(f\"[!] Error on reading PDF. {e} {exc_type} ({fname}:{exc_tb.tb_lineno})\")\n",
"        os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n",
"        continue\n",
"    elif file.endswith(\".doc\") or file.endswith(\".docx\"):\n",
"      try:\n",
"        loader = Docx2txtLoader(file_path)\n",
"      except Exception as e:\n",
"        exc_type, exc_obj, exc_tb = sys.exc_info()\n",
"        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]\n",
"        print(f\"[!] Error on reading Word document. {e} {exc_type} ({fname}:{exc_tb.tb_lineno})\")\n",
"        os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n",
"        continue\n",
"    elif file.endswith(\".html\") or file.endswith(\".htm\"):\n",
"      try:\n",
"        loader = BSHTMLLoader(file_path)\n",
"      except Exception as e:\n",
"        exc_type, exc_obj, exc_tb = sys.exc_info()\n",
"        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]\n",
"        print(f\"[!] Error on reading HTML. {e} {exc_type} ({fname}:{exc_tb.tb_lineno})\")\n",
"        os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n",
"        continue\n",
"    else:\n",
"      try:\n",
"        loader = TextLoader(file_path)\n",
"      except Exception as e:\n",
"        exc_type, exc_obj, exc_tb = sys.exc_info()\n",
"        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]\n",
"        print(f\"[!] Error on reading file. {e} {exc_type} ({fname}:{exc_tb.tb_lineno})\")\n",
"        os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n",
"        continue\n",
"\n",
"    try:\n",
"      store_vector(loader, dir)\n",
"      os.rename(file_path, dir[:-1] + LOADED_DIR_SUFFIX + file)\n",
"    except Exception as e:\n",
"      exc_type, exc_obj, exc_tb = sys.exc_info()\n",
"      fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]\n",
"      print(f\"[!] Error on Storing Vector. {e} {exc_type} ({fname}:{exc_tb.tb_lineno})\")\n",
"      os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n",
"      continue\n",
"\n",
"def store_vector(loader, dir):\n",
"  documents = loader.load()\n",
"  text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap_size)\n",
"  docs = text_splitter.split_documents(documents)\n",
"  print(len(docs))\n",
"  embeddings = OpenAIEmbeddings()\n",
"  faiss_db_path = dir[:-1] + DB_DIR_SUFFIX\n",
"  new_db = FAISS.from_documents(docs, embeddings)\n",
"  if os.path.exists(faiss_db_path + \"index.faiss\"):\n",
"    existing_db = FAISS.load_local(faiss_db_path, embeddings)\n",
"    new_db.merge_from(existing_db)\n",
"  new_db.save_local(faiss_db_path)\n",
"\n",
"\n",
"load_files(data_path + \"/\")\n",
"print('\\033[92m\\u2714 Done')"
],
"metadata": {
"cellView": "form",
"id": "d-YybjlDDk-d"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 5. Question\n",
"question = \"\" #@param {type:\"string\"}\n",
"\n",
"def get_answer(dir, question, num_of_chunk_send_to_chatgpt=num_of_chunk_send_to_chatgpt):\n",
"  embeddings = OpenAIEmbeddings()\n",
"  db_path = dir[:-1] + DB_DIR_SUFFIX\n",
"  vectordb = FAISS.load_local(db_path, embeddings)\n",
"  retriever = vectordb.as_retriever(\n",
"    search_type=\"similarity\",\n",
"    search_kwargs={\"k\": num_of_chunk_send_to_chatgpt}\n",
"  )\n",
"  try:\n",
"    with get_openai_callback() as cb:\n",
"      qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type=chain_type, retriever=retriever, return_source_documents=True)\n",
"      return qa({\"query\": question, \"cb\": cb})\n",
"  except Exception as e:\n",
"    if type(e).__name__ == \"InvalidRequestError\":\n",
"      print(f\"Token Exceed Error. Reducing Number of Chunk to {num_of_chunk_send_to_chatgpt-1}\")\n",
"      #InvalidRequestError:This model's maximum context length is 4097 tokens, however you requested 6072 tokens (5816 in your prompt; 256 for the completion). Please reduce your prompt; or completion length.\n",
"      return get_answer(dir, question, num_of_chunk_send_to_chatgpt-1)\n",
"    return type(e).__name__ + \":\" + str(e)\n",
"\n",
"\n",
"answer = get_answer(data_path + \"/\", question)\n",
"\n",
"if 'result' in answer:\n",
"  print(\"\\033[92m\\u2714ANSWER:\\n\")\n",
"  display(Markdown(f\"* {str(answer['result'])}\"))\n",
"  print(answer['cb'])\n",
"else:\n",
"  print(answer)\n"
],
"metadata": {
"cellView": "form",
"id": "TkeE7L74GVfo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 5.1 Documents related to the answer\n",
"for source_document in answer['source_documents']:\n",
"  if source_document.metadata:\n",
"    if 'source' in source_document.metadata and 'page' in source_document.metadata:\n",
"      display(Markdown(f\"## {source_document.metadata['source'].replace(data_path + '/', '')} (p.{source_document.metadata['page']})\"))\n",
"    elif 'source' in source_document.metadata:\n",
"      display(Markdown(f\"## {source_document.metadata['source'].replace(data_path + '/', '')}\"))\n",
"  display(Markdown(f\"* {source_document.page_content}\"))\n",
"print('\\033[92m\\u2714 Done')"
],
"metadata": {
"cellView": "form",
"id": "bHgoTtMaWRd-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 6. Miscellaneous\n",
"* If Google Drive is not linked, the ingested data (the vectorized documents) is lost when the Colab session ends or times out.\n",
"* If Google Drive is linked, a \"ShinoLang\" folder is created at the top level of your Drive, and its database is reused the next time you link.\n",
"\n",
"#### Ingesting more documents\n",
"* Upload additional files to the data folder and run 4.2 again.\n",
"\n",
"#### Asking more questions\n",
"* To ask further questions about already-ingested data, simply run 5.\n",
"\n",
"#### Technical details\n",
"Modules used:\n",
"* TextSplitter: RecursiveCharacterTextSplitter https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/recursive_text_splitter.html\n",
"* Embeddings: OpenAI https://python.langchain.com/en/latest/modules/models/text_embedding/examples/openai.html\n",
"* VectorDB: FAISS https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/faiss.html\n",
"* Chain: RetrievalQA https://python.langchain.com/en/latest/modules/chains/index_examples/vector_db_qa.html\n",
"* LLM: ChatGPT\n",
"\n"
],
"metadata": {
"id": "KnH9vW3Y_cub"
}
}
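,
{
"cell_type": "markdown",
"source": [
"The saved FAISS index can also be reused outside this notebook. A minimal sketch (assumes the `<data_dir>_db/` folder created above was copied locally, `OPENAI_API_KEY` is set, and `DB_DIR` is a hypothetical local path):\n",
"\n",
"```python\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"\n",
"DB_DIR = \"documents_db/\"  # hypothetical: wherever the index was copied\n",
"db = FAISS.load_local(DB_DIR, OpenAIEmbeddings())\n",
"for doc in db.similarity_search(\"your question\", k=3):\n",
"    print(doc.metadata.get(\"source\"), doc.page_content[:80])\n",
"```\n"
],
"metadata": {}
}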
]
}