Created
May 20, 2023 01:29
-
-
Save Sh1n0g1/eee5daef2ab9dacc14348c095620630e to your computer and use it in GitHub Desktop.
ShinoLang.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyOfCE3BG3+3zC/BLFpR/z4e", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Sh1n0g1/eee5daef2ab9dacc14348c095620630e/shinolang.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# ShinoLang\n", | |
"ドキュメントの学習および質問ができるシステム(LangChain + FAISS + ChatGPT)\n", | |
"読み込めるドキュメントは以下の通りである\n", | |
"* PDF\n", | |
"* Word文書 (.doc/.docx) ※ほとんど検証していない。旧形式の.docは読み込めない可能性が高い\n", | |
"* HTMLファイル (.htm/.html)\n", | |
"* その他のテキストファイル(.txt/\\*.\\*)\n", | |
"## 使い方(初回)\n", | |
"1. 上のメニューの`編集`から`出力をすべて消去`をクリックする\n", | |
"1. 真下の「1.設定」のフォームを埋める\n", | |
"1. 「5.質問」を埋める\n", | |
"1. 上のメニューから`ランタイム`の`すべてのセルを実行`をクリックする\n", | |
"1. 「4.1 学習(ファイルのアップロード)」で処理が止まるので、「3.パスの設定」で表示された`Data Path`に文書をアップロードしてからEnterを押す\n", | |
"1. 最終的に質問に対する答えが`5.質問`に、回答の根拠となった文書の抜粋が`5.1`に表示される\n", | |
"## 制限事項\n", | |
"* マルチバイトの文字を含むHTMLファイルはエラーになる場合があります。ブラウザの印刷機能でPDFに変換して読み込むのが良さそうです。\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "uVx7-oyKBBut" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 1.設定\n", | |
"import os\n", | |
"import sys\n", | |
"from IPython.display import clear_output\n", | |
"\n", | |
"#@markdown ### Googleドライブの連携\n", | |
"#@markdown 学習させた内容を次回以降も利用する場合は連携が必要となる\n", | |
"#@markdown 「ShinoLang」というディレクトリができる\n", | |
"\n", | |
"enable_google_drive = True #@param {type:\"boolean\"}\n", | |
"\n", | |
"#@markdown ### ChatGPTのAPIキー <font color='red'>*</font>\n", | |
"chatgpt_api_key = \"\" #@param {type:\"string\"}\n", | |
"\n", | |
"#@markdown ### ドキュメントを入れるフォルダ名 <font color='red'>*</font>\n", | |
"data_dir = \"documents\" #@param {type:\"string\"}\n", | |
"\n", | |
"#@markdown ### パラメータ\n", | |
"#@markdown よくわからない場合はデフォルトのままにする\n", | |
"\n", | |
"chunk_size = 500 #@param {type:\"integer\"}\n", | |
"chunk_overlap_size = 50 #@param {type:\"integer\"}\n", | |
"num_of_chunk_send_to_chatgpt = 10 #@param {type:\"integer\"}\n", | |
"chain_type = \"stuff\" #@param {type:\"string\"}\n", | |
"\n", | |
"# Validate the form above and export the API key for LangChain/OpenAI.\n", | |
"# Pasted values often carry stray spaces or newlines, which would make the\n", | |
"# key invalid or create an oddly named folder, so strip them first.\n", | |
"chatgpt_api_key = chatgpt_api_key.strip()\n", | |
"data_dir = data_dir.strip()\n", | |
"chain_type = chain_type.strip()\n", | |
"\n", | |
"# Fall back to an OPENAI_API_KEY that is already set in the environment:\n", | |
"# Colab writes form values back into the cell source, so a key typed into\n", | |
"# the form is saved (and shared) together with the notebook.\n", | |
"if chatgpt_api_key == \"\":\n", | |
"    chatgpt_api_key = os.environ.get(\"OPENAI_API_KEY\", \"\")\n", | |
"\n", | |
"if chatgpt_api_key==\"\" or data_dir==\"\":\n", | |
"    print(\"\\033[91m入力必須(*)の設定を入れてください。\")\n", | |
"    # Raises SystemExit to abort this cell.\n", | |
"    sys.exit()\n", | |
"\n", | |
"# Check the tuning parameters now; otherwise a bad value only surfaces in\n", | |
"# 4.2 or 5 (e.g. an unknown chain_type after every document was embedded).\n", | |
"if chunk_size <= 0 or not 0 <= chunk_overlap_size < chunk_size or num_of_chunk_send_to_chatgpt < 1:\n", | |
"    print(\"\\033[91mパラメータが不正です。chunk_size > 0、0 <= chunk_overlap_size < chunk_size、num_of_chunk_send_to_chatgpt >= 1 にしてください。\")\n", | |
"    sys.exit()\n", | |
"if chain_type not in (\"stuff\", \"map_reduce\", \"refine\", \"map_rerank\"):\n", | |
"    print(\"\\033[91mchain_type は stuff / map_reduce / refine / map_rerank のいずれかにしてください。\")\n", | |
"    sys.exit()\n", | |
"\n", | |
"os.environ[\"OPENAI_API_KEY\"] = chatgpt_api_key\n", | |
"\n", | |
"clear_output()\n", | |
"print('\\033[92m\\u2714 Done')\n", | |
"\n" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "ZS73dOyGAleW" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "ugiW3SYv-q00" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title 2.必要モジュールのインストール&インポート\n", | |
"\n", | |
"# One %pip call (targets this kernel's environment) instead of five !pip calls.\n", | |
"# docx2txt is required by Docx2txtLoader; without it Word files fail to load.\n", | |
"# The upper bounds keep the pre-1.0 OpenAI client and the pre-0.1 LangChain\n", | |
"# APIs this notebook uses (langchain.vectorstores etc., and FAISS.load_local\n", | |
"# without allow_dangerous_deserialization, which newer langchain-community requires).\n", | |
"%pip install -q pypdf docx2txt faiss-cpu tiktoken \"openai<1\" \"langchain<0.1\" \"langchain-community<0.0.27\"\n", | |
"\n", | |
"from IPython.display import display, Markdown\n", | |
"from google.colab import drive\n", | |
"from langchain.vectorstores import FAISS\n", | |
"from langchain.embeddings import OpenAIEmbeddings\n", | |
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n", | |
"from langchain.document_loaders import TextLoader\n", | |
"from langchain.document_loaders import PyPDFLoader\n", | |
"from langchain.document_loaders import Docx2txtLoader\n", | |
"from langchain.document_loaders import BSHTMLLoader\n", | |
"from langchain.chains import RetrievalQA\n", | |
"from langchain.llms import OpenAI\n", | |
"from langchain.callbacks import get_openai_callback\n", | |
"\n", | |
"clear_output()\n", | |
"print('\\033[92m\\u2714 Done')\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 3.パスの設定\n", | |
"GOOGLE_DRIVE_PATH='/content/gdrive/'\n", | |
"ROOT_DIR='ShinoLang/'\n", | |
"LOADED_DIR_SUFFIX=\"_loaded/\"\n", | |
"FAILED_DIR_SUFFIX=\"_failed/\"\n", | |
"DB_DIR_SUFFIX=\"_db/\"\n", | |
"\n", | |
"# Resolve where uploads, processed files and the FAISS index live:\n", | |
"# <install_path><data_dir> for uploads plus sibling <data_dir>_loaded/,\n", | |
"# <data_dir>_failed/ and <data_dir>_db/ directories (created if missing).\n", | |
"if enable_google_drive:\n", | |
"    drive.mount(GOOGLE_DRIVE_PATH)\n", | |
"    install_path=GOOGLE_DRIVE_PATH + 'MyDrive/' + ROOT_DIR\n", | |
"else:\n", | |
"    install_path='/content/' + ROOT_DIR\n", | |
"# makedirs(exist_ok=True) also creates missing parents, so a nested\n", | |
"# data_dir such as docs/pdf works (os.mkdir raised FileNotFoundError).\n", | |
"os.makedirs(install_path, exist_ok=True)\n", | |
"os.chdir(install_path)\n", | |
"\n", | |
"# Drop leading/trailing slashes: the sibling directories are built by\n", | |
"# appending suffixes to data_dir, and a leading slash would produce //.\n", | |
"data_dir=data_dir.strip(\"/\")\n", | |
"if data_dir==\"\":\n", | |
"    print(\"\\033[91mドキュメントを入れるフォルダ名を正しく入力してください。\")\n", | |
"    sys.exit()\n", | |
"\n", | |
"data_path=install_path + data_dir\n", | |
"loaded_path=data_path + LOADED_DIR_SUFFIX\n", | |
"failed_path=data_path + FAILED_DIR_SUFFIX\n", | |
"db_path=data_path + DB_DIR_SUFFIX\n", | |
"for directory in (data_path, loaded_path, failed_path, db_path):\n", | |
"    os.makedirs(directory, exist_ok=True)\n", | |
"\n", | |
"print(f\"Install Path: {install_path}\")\n", | |
"print(f\" Data Path: {data_path}\")\n", | |
"print(f\" Loaded Path: {loaded_path}\")\n", | |
"print(f\" Failed Path: {failed_path}\")\n", | |
"print(f\" DB Path: {db_path}\")\n", | |
"print('\\033[92m\\u2714 Done')\n" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "_zMo447r_H5C" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 4.1 学習(ファイルのアップロード)\n", | |
"#@markdown ファイルのアップロード\n", | |
"#@markdown * 学習させたいファイルを上記の `Data Path`にアップロードする。\n", | |
"#@markdown * ディレクトリが表示されない場合はディレクトリツリーから`更新`をする。\n", | |
"#@markdown * Googleドライブを利用している場合はそっちから直接アップロードしたほうが早い。\n", | |
"#@markdown * ファイルのアップロードが完了または学習をスキップするには[Enter]を押して続行する。\n", | |
"\n", | |
"# Blocks here, also during Run all, until the user presses Enter, so that\n", | |
"# documents can be uploaded to the Data Path printed by cell 3 before\n", | |
"# cell 4.2 embeds them. Pressing Enter without uploading skips learning.\n", | |
"# input() is the last expression, so the entered text is echoed as output.\n", | |
"input(\"Enterを入力してください\")\n", | |
"\n", | |
"\n" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "Ms0eaoPZU1aa" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 4.2 学習(実行)\n", | |
"#@markdown 学習を行います。大きい文書だと数分かかる場合がございます。\n", | |
"import traceback\n", | |
"\n", | |
"# Loader class and error-message label per file extension. Extensions are\n", | |
"# compared case-insensitively, so REPORT.PDF is read as a PDF rather than\n", | |
"# as text; anything not listed falls back to TextLoader.\n", | |
"# NOTE(review): docx2txt reads .docx (zip) files; legacy binary .doc files\n", | |
"# will most likely fail and be moved to the failed directory.\n", | |
"_LOADERS_BY_EXTENSION = {\n", | |
"    \".pdf\": (PyPDFLoader, \"PDF\"),\n", | |
"    \".doc\": (Docx2txtLoader, \"Doc\"),\n", | |
"    \".docx\": (Docx2txtLoader, \"Doc\"),\n", | |
"    \".html\": (BSHTMLLoader, \"HTML\"),\n", | |
"    \".htm\": (BSHTMLLoader, \"HTML\"),\n", | |
"}\n", | |
"\n", | |
"\n", | |
"def _report_error(stage, e):\n", | |
"    \"\"\"Print a one-line error report for exception `e` raised during `stage`.\"\"\"\n", | |
"    # Report the innermost frame, i.e. where `e` was actually raised (inside\n", | |
"    # the loader or LangChain), not the except-clause in load_files.\n", | |
"    origin = traceback.extract_tb(e.__traceback__)[-1]\n", | |
"    fname = os.path.split(origin.filename)[1]\n", | |
"    print(f\"[!] Error on {stage}. {e} {type(e)} ({fname}:{origin.lineno})\")\n", | |
"\n", | |
"\n", | |
"def load_files(dir):\n", | |
"    \"\"\"Embed every regular file directly inside `dir` and file it away.\n", | |
"\n", | |
"    `dir` must end with a slash: the sibling directories are derived as\n", | |
"    dir[:-1] + LOADED_DIR_SUFFIX and dir[:-1] + FAILED_DIR_SUFFIX. A file is\n", | |
"    moved to the loaded directory once store_vector() succeeds, or to the\n", | |
"    failed directory if creating its loader or storing its vectors raises,\n", | |
"    so a re-run only processes newly uploaded files.\n", | |
"    \"\"\"\n", | |
"    files = sorted(f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f)))\n", | |
"    print(f\"Files Loading... {len(files)} files.\")\n", | |
"    for file in files:\n", | |
"        print(f\"[+] Reading {file}\")\n", | |
"        #TODO Check if the file already exists on \"loaded_dir\"\n", | |
"\n", | |
"        file_path = os.path.join(dir, file)\n", | |
"        extension = os.path.splitext(file)[1].lower()\n", | |
"        loader_class, label = _LOADERS_BY_EXTENSION.get(extension, (TextLoader, \"file\"))\n", | |
"        try:\n", | |
"            loader = loader_class(file_path)\n", | |
"        except Exception as e:\n", | |
"            _report_error(f\"reading {label}\", e)\n", | |
"            os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n", | |
"            continue\n", | |
"\n", | |
"        try:\n", | |
"            store_vector(loader, dir)\n", | |
"            os.rename(file_path, dir[:-1] + LOADED_DIR_SUFFIX + file)\n", | |
"        except Exception as e:\n", | |
"            _report_error(\"Storing Vector\", e)\n", | |
"            os.rename(file_path, dir[:-1] + FAILED_DIR_SUFFIX + file)\n", | |
"\n", | |
"def store_vector(loader, dir):\n", | |
"    \"\"\"Chunk one document, embed the chunks and merge them into the FAISS index.\n", | |
"\n", | |
"    The index is stored at dir[:-1] + DB_DIR_SUFFIX. When an index already\n", | |
"    exists there it is loaded and merged in before saving, so the index\n", | |
"    accumulates documents across calls and sessions.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if the document yields no text chunks (e.g. an empty file\n", | |
"            or a PDF without a text layer).\n", | |
"    \"\"\"\n", | |
"    documents = loader.load()\n", | |
"    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap_size)\n", | |
"    docs = text_splitter.split_documents(documents)\n", | |
"    print(f\"    {len(docs)} chunks\")\n", | |
"    if not docs:\n", | |
"        # FAISS.from_documents needs at least one vector to build an index;\n", | |
"        # fail here with a readable reason instead of an obscure error.\n", | |
"        raise ValueError(\"no text chunks were extracted from this file\")\n", | |
"    embeddings = OpenAIEmbeddings()\n", | |
"    faiss_db_path=dir[:-1] + DB_DIR_SUFFIX\n", | |
"    new_db = FAISS.from_documents(docs, embeddings)\n", | |
"    if os.path.exists(faiss_db_path + \"index.faiss\"):\n", | |
"        existing_db=FAISS.load_local(faiss_db_path, embeddings)\n", | |
"        new_db.merge_from(existing_db)\n", | |
"    new_db.save_local(faiss_db_path)\n", | |
"\n", | |
"\n", | |
"# load_files derives its sibling directories from a path ending in a slash.\n", | |
"load_files(f\"{data_path}/\")\n", | |
"print('\\033[92m\\u2714 Done')" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "d-YybjlDDk-d" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 5.質問\n", | |
"question = \"\" #@param {type:\"string\"}\n", | |
"\n", | |
"def get_answer(dir, question, num_of_chunk_send_to_chatgpt=num_of_chunk_send_to_chatgpt):\n", | |
"    \"\"\"Answer `question` with a RetrievalQA chain over the FAISS index for `dir`.\n", | |
"\n", | |
"    Args:\n", | |
"        dir: document directory path ending with a slash; the index is read\n", | |
"            from dir[:-1] + DB_DIR_SUFFIX.\n", | |
"        question: question text passed to the chain as the query.\n", | |
"        num_of_chunk_send_to_chatgpt: how many of the most similar chunks (k)\n", | |
"            are put into the prompt. If OpenAI rejects the request for\n", | |
"            exceeding the model's context length, k is lowered by one and the\n", | |
"            request retried, down to a minimum of 1.\n", | |
"\n", | |
"    Returns:\n", | |
"        On success, the chain's output dict; its cb entry is the OpenAI\n", | |
"        callback holding token usage and cost. On failure, an\n", | |
"        ErrorName:message string.\n", | |
"    \"\"\"\n", | |
"    db_path=dir[:-1] + DB_DIR_SUFFIX\n", | |
"    if not os.path.exists(db_path + \"index.faiss\"):\n", | |
"        # Nothing has been learned yet (4.2 never stored an index), and\n", | |
"        # FAISS.load_local would fail with a low-level error.\n", | |
"        return f\"NoIndexError:{db_path} に学習済みデータがありません。先に「4.学習」を実行してください。\"\n", | |
"    embeddings = OpenAIEmbeddings()\n", | |
"    # Load the index once; only the retriever changes between retries.\n", | |
"    vectordb = FAISS.load_local(db_path, embeddings)\n", | |
"\n", | |
"    k = num_of_chunk_send_to_chatgpt\n", | |
"    while True:\n", | |
"        retriever = vectordb.as_retriever(\n", | |
"            search_type=\"similarity\",\n", | |
"            search_kwargs={\"k\": k}\n", | |
"        )\n", | |
"        try:\n", | |
"            with get_openai_callback() as cb:\n", | |
"                qa=RetrievalQA.from_chain_type(llm=OpenAI(), chain_type=chain_type, retriever=retriever, return_source_documents=True)\n", | |
"                # cb is not a chain input; it rides along in the inputs so the\n", | |
"                # caller can read answer['cb'] from the returned dict.\n", | |
"                return qa({\"query\": question, \"cb\": cb})\n", | |
"        except Exception as e:\n", | |
"            # A too-long prompt looks like: InvalidRequestError: This model's maximum\n", | |
"            # context length is 4097 tokens, however you requested 6072 tokens\n", | |
"            # (5816 in your prompt; 256 for the completion). Please reduce ...\n", | |
"            # Only that error is worth retrying with fewer chunks.\n", | |
"            is_context_overflow = type(e).__name__ == \"InvalidRequestError\" and \"maximum context length\" in str(e)\n", | |
"            if is_context_overflow and k > 1:\n", | |
"                print(f\"Token Exceed Error. Reducing Number of Chunk to {k-1}\")\n", | |
"                k -= 1\n", | |
"                continue\n", | |
"            return str(type(e).__name__ + \":\" + str(e))\n", | |
"\n", | |
"\n", | |
"\n", | |
"if not question.strip():\n", | |
"    # Stop instead of paying for an embedding and completion of an empty query.\n", | |
"    print(\"\\033[91m質問を入力してください。\")\n", | |
"    sys.exit()\n", | |
"\n", | |
"answer=get_answer(data_path + \"/\", question)\n", | |
"\n", | |
"# get_answer returns a dict on success and an ErrorName:message string on\n", | |
"# failure; a membership test on a string would be a substring search.\n", | |
"if isinstance(answer, dict) and 'result' in answer:\n", | |
"    print(\"\\033[92m\\u2714ANSWER:\\n\")\n", | |
"    display(Markdown(f\"* {str(answer['result'])}\"))\n", | |
"    # Token usage / cost collected by get_openai_callback.\n", | |
"    print(answer['cb'])\n", | |
"else:\n", | |
"    print(answer)\n", | |
"\n" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "TkeE7L74GVfo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 5.1 回答に関連するドキュメント\n", | |
"# Show the chunks the answer was based on, headed by their file name.\n", | |
"# `answer` comes from cell 5 and is an error string when that cell failed,\n", | |
"# in which case there is nothing to show.\n", | |
"source_documents = answer.get('source_documents', []) if isinstance(answer, dict) else []\n", | |
"for source_document in source_documents:\n", | |
"    metadata = source_document.metadata or {}\n", | |
"    if 'source' in metadata:\n", | |
"        # basename works regardless of where the index was built (Drive or /content).\n", | |
"        source_name = os.path.basename(metadata['source'])\n", | |
"        if 'page' in metadata:\n", | |
"            # PyPDFLoader numbers pages from 0; show 1-based page numbers.\n", | |
"            display(Markdown(f\"## {source_name} (p.{metadata['page'] + 1})\"))\n", | |
"        else:\n", | |
"            display(Markdown(f\"## {source_name}\"))\n", | |
"    display(Markdown(f\"* {source_document.page_content}\"))\n", | |
"print('\\033[92m\\u2714 Done')" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "bHgoTtMaWRd-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 6.その他\n", | |
"* Google Driveと連携していない場合、Colabを終了するか、タイムアウトすると、学習したデータ(ベクター化した文書)は消えてしまう\n", | |
"* Google Driveと連携している場合は、Googleドライブ直下に「ShinoLang」フォルダができるので、次回再度連携させたときにそのデータベースを利用することができる\n", | |
"\n", | |
"#### 追加学習\n", | |
"* データフォルダにファイルをアップロードして、4.2を実行すれば追加学習ができる\n", | |
"\n", | |
"#### 追加質問\n", | |
"* 学習済みのデータについて質問をする場合は、5.を実施すればよい\n", | |
"\n", | |
"#### 技術的な詳細\n", | |
"利用しているモジュールの詳細:\n", | |
"* TextSplitter: RecursiveCharacterTextSplitter https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/recursive_text_splitter.html\n", | |
"* Embeddings: OpenAI https://python.langchain.com/en/latest/modules/models/text_embedding/examples/openai.html\n", | |
"* VectorDB: FAISS https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/faiss.html\n", | |
"* Chain: RetrievalQA: https://python.langchain.com/en/latest/modules/chains/index_examples/vector_db_qa.html\n", | |
"* LLM: OpenAI(`langchain.llms.OpenAI`をデフォルトのモデル設定で使用。チャットモデルではない)\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "KnH9vW3Y_cub" | |
} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment