Skip to content

Instantly share code, notes, and snippets.

@gh640
Last active March 14, 2024 06:48
Show Gist options
  • Save gh640/e77496940e603618ed305ec7dddab1ea to your computer and use it in GitHub Desktop.
サンプル: OpenAI Assistants stream API を使う
"""OpenAI の Assistant stream API を使う
Usage:
python -m pip install 'openai==1.14.0'
export OPENAI_API_KEY='...'
python openai_assistant_stream.py
See: https://platform.openai.com/docs/assistants/overview?context=with-streaming
"""
from openai import AssistantEventHandler
import os
import sys
from pathlib import Path
from typing import override
from openai import OpenAI
ASSISTANT_ID = "..."
MESSAGE = "..."
FILE_OUT = Path(__file__).resolve().parent / "out.txt"
def main():
    """Send MESSAGE to the assistant ASSISTANT_ID, stream the reply, and save it.

    While the run streams, EventHandler echoes output to stdout. Afterwards,
    run metadata is printed to stderr and the reply text is written to
    FILE_OUT as UTF-8.

    Exits with an error message if OPENAI_API_KEY is not set, or if the run
    produced no text content to save.
    """
    if "OPENAI_API_KEY" not in os.environ:
        sys.exit("Environment variable `OPENAI_API_KEY` is required.")

    client = OpenAI()
    assistant = client.beta.assistants.retrieve(ASSISTANT_ID)
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=MESSAGE,
    )

    with client.beta.threads.runs.create_and_stream(
        thread_id=thread.id,
        assistant_id=assistant.id,
        event_handler=EventHandler(),
    ) as stream:
        stream.until_done()
        final_messages = stream.get_final_messages()
        run = stream.get_final_run()

    # Gather every text part instead of assuming messages[0].content[0] exists
    # and is text. That assumption raised IndexError for an empty reply and
    # AttributeError for non-text parts such as image files.
    response = _collect_text(final_messages)

    eprint(f"Assistant ID: {run.assistant_id=}")
    eprint(f"Run ID: {run.id=}")
    eprint(f"Thread ID: {run.thread_id=}")
    eprint(f"Model: {run.model=}")
    eprint(f"Status: {run.status=}")
    eprint(f"Usage: {run.usage=}")

    # Checked after the metadata is printed so the run status explains why.
    if not response:
        sys.exit("No text content in the assistant's reply; nothing saved.")

    # Use an explicit encoding: the platform default may not represent
    # non-ASCII replies.
    FILE_OUT.write_text(response, encoding="utf-8")
    eprint(f"File saved: {FILE_OUT}")


def _collect_text(messages):
    """Join the text parts of *messages* with newlines, skipping non-text parts."""
    return "\n".join(
        part.text.value
        for message in messages
        for part in message.content
        if part.type == "text"
    )
def eprint(msg):
    """Print *msg* to standard error, followed by a newline."""
    err_stream = sys.stderr
    print(msg, end="\n", file=err_stream)
class EventHandler(AssistantEventHandler):
    """Echo a streaming assistant run to stdout as events arrive.

    Prints a prefix for each new text part or tool call, prints text deltas
    inline, and prints code-interpreter input plus any "logs" outputs.
    """

    @override
    def on_text_created(self, text) -> None:
        # A new text part starts: print the speaker prefix on a fresh line.
        print("\nassistant > ", end="", flush=True)

    @override
    def on_text_delta(self, delta, snapshot):
        # Emit each streamed chunk inline, with no trailing newline.
        print(delta.value, end="", flush=True)

    @override
    def on_tool_call_created(self, tool_call):
        print(f"\nassistant > {tool_call.type}\n", flush=True)

    @override
    def on_tool_call_delta(self, delta, snapshot):
        # Only code-interpreter deltas are rendered; other tool types are ignored.
        if delta.type != 'code_interpreter':
            return
        interpreter = delta.code_interpreter
        if interpreter.input:
            print(interpreter.input, end="", flush=True)
        if not interpreter.outputs:
            return
        print("\n\noutput >", flush=True)
        log_items = (item for item in interpreter.outputs if item.type == "logs")
        for item in log_items:
            print(f"\n{item.logs}", flush=True)

    @override
    def on_end(self):
        # Finish the streamed output with a newline.
        print()
if __name__ == "__main__":
main()
@gh640
Copy link
Author

gh640 commented Mar 14, 2024

requirements.txt

openai==1.14.0

@gh640
Copy link
Author

gh640 commented Mar 14, 2024

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment