@koganei
Created May 5, 2023 21:03
Websocket Support for Haystack Agent
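The script below wires a Haystack conversational agent into a FastAPI websocket endpoint: each incoming query runs the blocking agent.chat() call on a worker thread, while tool start/finish callbacks and the final result are relayed to the client through a per-connection, thread-safe queue owned by a ConnectionManager (imported here but not included in the gist; a sketch follows the code). Agents are cached in activeAgents per agent_id so conversation memory persists across messages.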
import asyncio
import sys
import threading
from typing import Any
import json
import os
import logging
from fastapi import FastAPI, WebSocket, BackgroundTasks, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from functools import partial
from connection_manager import ConnectionManager
app = FastAPI()
# Add CORS headers to FastAPI
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])
connection_manager = ConnectionManager()
openaikey = os.environ["OPENAI_API_KEY"]
activeAgents = {}
logging.basicConfig(
    format="%(levelname)s - %(name)s -- %(message)s\n\n", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)
def create_agent(id, personality="hr", openai_model="gpt-3.5-turbo", websocket=None):
    from tools.haystack_memory_custom.prompt_templates import get_memory_template
    from tools.haystack_memory_custom.utils import MemoryUtils
    from haystack.nodes import PromptNode
    from haystack_custom_agent.agent import CustomAgent
    import load_tools

    memory_database = []
    prompt_node = PromptNode(model_name_or_path=openai_model, api_key=openaikey, max_length=2000,
                             stop_words=["Observation:"])
    agent = CustomAgent(prompt_node=prompt_node,
                        prompt_template=get_memory_template(personality, memory=memory_database))
    load_tools.load_tools(agent, memory_database)
    load_tools.load_post_processing_pipeline(agent, memory_database)

    agent_data = {
        'personality': personality,
        'agent': MemoryUtils(memory_database=memory_database, agent=agent),
        'memory_database': memory_database,
        "node": prompt_node,
        "unwrapped_agent": agent
    }
    activeAgents[id] = agent_data

    current_tool = ""

    def on_tool_start(cm, tool_input, tool, **kwargs):
        # nonlocal is required so on_tool_finish sees the tool name;
        # a bare assignment would only create a local variable here.
        nonlocal current_tool
        current_tool = tool.name
        print("on tool start")
        cm.put_message(websocket, json.dumps({
            "type": "tool_start",
            "tool": tool.name,
            "tool_input": tool_input
        }))

    def on_tool_finish(cm, tool_result, **kwargs):
        print("on tool finish:" + json.dumps({
            "type": "tool_finish",
            "tool": current_tool,
            "tool_input": tool_result
        }))
        cm.put_message(websocket, json.dumps({
            "type": "tool_finish",
            "tool": current_tool,
            "tool_input": tool_result
        }))

    agent.callback_manager.on_tool_start += partial(on_tool_start, connection_manager)
    agent.callback_manager.on_tool_finish += partial(on_tool_finish, connection_manager)

    return agent_data["agent"]
def main():
    from tools.haystack_memory_custom.prompt_templates import get_memory_template

    def inference_function(agent_id, query, personality, openai_model, websocket):
        if agent_id in activeAgents:
            # Reuse the cached agent so conversation memory persists, but
            # rebuild its prompt template against the current memory database.
            connection_manager.put_message(websocket, "{\"log\": \"Using existing agent: " + str(agent_id) + "\"}")
            agentObject = activeAgents[agent_id]
            agent = agentObject["agent"]
            agentObject["unwrapped_agent"].prompt_template = agentObject["node"].get_prompt_template(
                get_memory_template(agentObject["personality"],
                                    memory=agentObject["memory_database"])
            )
        else:
            connection_manager.put_message(websocket, "{\"log\": \"Creating new agent: " + str(agent_id) +
                                           " with model: " + openai_model + "\"}")
            agent = create_agent(agent_id, personality, openai_model, websocket)
        result = agent.chat(query)
        send_text = {
            "result": result
        }
        connection_manager.put_message(websocket, json.dumps(send_text))
        # "close" is a sentinel telling the websocket loop to stop draining the queue.
        connection_manager.put_message(websocket, "close")
        return result  # returned for the CLI path; the websocket path reads the queue instead
    def start_server():
        # This only registers the route on the module-level app; the app itself
        # is presumably served externally (e.g. by uvicorn).
        print("starting server")

        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            await connection_manager.connect(websocket)
            try:
                while True:
                    data = await websocket.receive_text()
                    await websocket.send_text("{\"log\":\"Starting Inference\"}")
                    try:
                        json_data = json.loads(data)
                        agent_id = json_data["agent_id"]
                        query = json_data["query"]
                        # Run the blocking agent call on a worker thread so the event loop stays responsive.
                        t = threading.Thread(target=inference_function, args=(agent_id, query, "hr", "gpt-4", websocket,))
                        t.start()
                        # Drain the per-connection queue until the worker signals "close".
                        while True:
                            if not connection_manager.message_queues[websocket].empty():
                                msg_to_send = connection_manager.get_message(websocket)
                                if msg_to_send == "close":
                                    break
                                await connection_manager.send_message(websocket, msg_to_send)
                            else:
                                await asyncio.sleep(0.1)  # Sleep for a short duration before checking the queue again
                    except Exception as e:
                        error = {
                            "error": "Something went wrong",
                            "message": str(e)
                        }
                        await websocket.send_text(json.dumps(error))
            except WebSocketDisconnect:
                print("WebSocket disconnected")
                connection_manager.disconnect(websocket)
                # Clean up agent
    def run_in_command_line():
        # The CLI path has no websocket; passing None assumes ConnectionManager
        # tolerates a None key (see the sketch below, whose put_message ignores
        # unknown sockets).
        result = inference_function(
            123,
            # "Hi! Can you please update the timecard to add 4 hours to Peter on last friday and a comment to Peter's Timecard Exception Task that he checked in with me this morning?"
            "Hi! Can you tell me about my current tasks?", "hr", "gpt-4", None)
        print('\n\n\n\n==============RESULT TO THE FRONTEND:\n\n' +
              result + '\n\n\n\n=====================')
    try:
        arg_command = sys.argv[1]
    except IndexError:
        arg_command = ""

    if arg_command == "--cli":
        run_in_command_line()
    else:
        start_server()


main()
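The script imports connection_manager, which is not included in the gist. Below is a minimal sketch of what that module would need to provide, inferred purely from how it is used above (async connect/send_message, thread-safe put_message/get_message, a message_queues mapping); the real implementation may differ.

# connection_manager.py -- hypothetical sketch, inferred from usage above.
import queue
from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        self.active_connections: list[WebSocket] = []
        # queue.Queue is thread-safe, which matters because put_message is
        # called from the inference worker thread while the event loop reads.
        self.message_queues: dict[WebSocket, queue.Queue] = {}

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        self.message_queues[websocket] = queue.Queue()

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)
        self.message_queues.pop(websocket, None)

    def put_message(self, websocket, message: str):
        # Safe to call from any thread; silently drops messages for unknown
        # sockets, which also covers the CLI path where websocket is None.
        q = self.message_queues.get(websocket)
        if q is not None:
            q.put(message)

    def get_message(self, websocket: WebSocket) -> str:
        return self.message_queues[websocket].get()

    async def send_message(self, websocket: WebSocket, message: str):
        await websocket.send_text(message)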
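For testing, the endpoint expects a JSON payload with agent_id and query, and streams back log lines, tool_start/tool_finish events, and finally a {"result": ...} object. Here is a hypothetical client using the third-party websockets package; the host and port are assumptions, since the gist does not show how the app is served.

# client.py -- hypothetical test client (pip install websockets).
import asyncio
import json
import websockets


async def ask(agent_id: int, query: str):
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({"agent_id": agent_id, "query": query}))
        while True:
            message = json.loads(await ws.recv())
            print(message)
            # inference_function ends the stream with a "result" payload;
            # the server reports failures with an "error" payload.
            if "result" in message or "error" in message:
                break


asyncio.run(ask(123, "Hi! Can you tell me about my current tasks?"))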