Last active
July 27, 2023 08:35
-
-
Save kkdai/a99e1a78906a4c1eaac1fea1f440aa98 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import aiohttp | |
from fastapi import Request, FastAPI, HTTPException | |
from langchain.chat_models import ChatOpenAI | |
from langchain.chains import ConversationChain | |
from langchain.memory import ConversationBufferWindowMemory | |
from linebot import ( | |
AsyncLineBotApi, WebhookParser | |
) | |
from linebot.aiohttp_async_http_client import AiohttpAsyncHttpClient | |
from linebot.exceptions import ( | |
InvalidSignatureError | |
) | |
from linebot.models import ( | |
MessageEvent, TextMessage, TextSendMessage, | |
) | |
import os | |
import openai | |
from dotenv import load_dotenv, find_dotenv | |
_ = load_dotenv(find_dotenv()) # read local .env file | |
# get channel_secret and channel_access_token from your environment variable | |
channel_secret = os.getenv('ChannelSecret', None) | |
channel_access_token = os.getenv('ChannelAccessToken', None) | |
if channel_secret is None: | |
print('Specify LINE_CHANNEL_SECRET as environment variable.') | |
sys.exit(1) | |
if channel_access_token is None: | |
print('Specify LINE_CHANNEL_ACCESS_TOKEN as environment variable.') | |
sys.exit(1) | |
app = FastAPI() | |
session = aiohttp.ClientSession() | |
async_http_client = AiohttpAsyncHttpClient(session) | |
line_bot_api = AsyncLineBotApi(channel_access_token, async_http_client) | |
parser = WebhookParser(channel_secret) | |
# Langchain 串接 OpenAI ,這裡 model 可以先選 gpt-3.5-turbo | |
llm = ChatOpenAI(temperature=0.9, model='gpt-3.5-turbo') | |
# 透過 ConversationBufferWindowMemory 快速打造一個具有「記憶力」的聊天機器人,可以記住至少五回。 | |
# 通常來說 5 回還蠻夠的 | |
memory = ConversationBufferWindowMemory(k=5) | |
conversation = ConversationChain( | |
llm=llm, | |
memory=memory, | |
verbose=False | |
) | |
@app.post("/callback") | |
async def handle_callback(request: Request): | |
signature = request.headers['X-Line-Signature'] | |
# get request body as text | |
body = await request.body() | |
body = body.decode() | |
try: | |
events = parser.parse(body, signature) | |
except InvalidSignatureError: | |
raise HTTPException(status_code=400, detail="Invalid signature") | |
for event in events: | |
if not isinstance(event, MessageEvent): | |
continue | |
if not isinstance(event.message, TextMessage): | |
continue | |
# 將使用者傳來的訊息 event.message.text 當成輸入,等 LangChain 傳回結果。 | |
ret = conversation.predict(input=event.message.text) | |
await line_bot_api.reply_message( | |
event.reply_token, | |
TextSendMessage(text=ret) | |
) | |
return 'OK' |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment