Last active
March 14, 2024 06:48
-
-
Save gh640/e77496940e603618ed305ec7dddab1ea to your computer and use it in GitHub Desktop.
サンプル: OpenAI Assistants stream API を使う
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""OpenAI の Assistant stream API を使う | |
Usage: | |
python -m pip install 'openai==1.14.0' | |
export OPENAI_API_KEY='...' | |
python openai_assistant_stream.py | |
See: https://platform.openai.com/docs/assistants/overview?context=with-streaming | |
""" | |
from openai import AssistantEventHandler | |
import os | |
import sys | |
from pathlib import Path | |
from typing import override | |
from openai import OpenAI | |
# ID of the assistant to run; the "..." value is a stand-in to replace.
ASSISTANT_ID = "..."

# Text sent as the user message; the "..." value is a stand-in to replace.
MESSAGE = "..."

# The assistant's reply is saved as ``out.txt`` in this script's directory.
FILE_OUT = Path(__file__).resolve().with_name("out.txt")
def main():
    """Run the configured assistant once and stream its reply.

    Creates a new thread containing ``MESSAGE``, starts a streaming run of
    the assistant ``ASSISTANT_ID`` with an ``EventHandler`` that prints events
    as they arrive, then reports run metadata on stderr and saves the final
    reply text to ``FILE_OUT`` (UTF-8).

    Exits through ``sys.exit`` with an error message when the
    ``OPENAI_API_KEY`` environment variable is missing, or when the run
    finished without any text reply to save.
    """
    if "OPENAI_API_KEY" not in os.environ:
        sys.exit("Environment variable `OPENAI_API_KEY` is required.")

    client = OpenAI()
    assistant = client.beta.assistants.retrieve(ASSISTANT_ID)
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=MESSAGE,
    )

    with client.beta.threads.runs.create_and_stream(
        thread_id=thread.id,
        assistant_id=assistant.id,
        event_handler=EventHandler(),
    ) as stream:
        stream.until_done()
        # Collect the final results while the stream object is still open.
        final_messages = stream.get_final_messages()
        run = stream.get_final_run()

    # Report the run metadata first so a failed or empty run can still be
    # diagnosed from its status.
    eprint(f"Assistant ID: {run.assistant_id=}")
    eprint(f"Run ID: {run.id=}")
    eprint(f"Thread ID: {run.thread_id=}")
    eprint(f"Model: {run.model=}")
    eprint(f"Status: {run.status=}")
    eprint(f"Usage: {run.usage=}")

    response = _reply_text(final_messages)
    if response is None:
        sys.exit("The run finished without a text reply; nothing was saved.")

    # Explicit UTF-8: the locale default encoding may be unable to represent
    # the reply (e.g. Japanese text on a legacy Windows code page).
    FILE_OUT.write_text(response, encoding="utf-8")
    eprint(f"File saved: {FILE_OUT}")


def _reply_text(messages):
    """Return the text of the first message in ``messages``, or ``None``.

    Only content parts whose ``type`` is ``"text"`` contribute; other parts
    (for example image files) are skipped. Several text parts are joined with
    a newline. ``None`` is returned when ``messages`` is empty or the first
    message has no text part.
    """
    if not messages:
        return None
    text_parts = [
        part.text.value for part in messages[0].content if part.type == "text"
    ]
    if not text_parts:
        return None
    return "\n".join(text_parts)
def eprint(msg):
    """Print ``msg`` followed by a newline to standard error."""
    error_stream = sys.stderr
    print(msg, file=error_stream)
class EventHandler(AssistantEventHandler):
    """Print streamed assistant events to stdout as they arrive.

    Every write is flushed immediately so the reply appears incrementally
    instead of all at once when the run ends.
    """

    @override
    def on_text_created(self, text) -> None:
        """Start a fresh ``assistant >`` line when a text reply begins."""
        print("\nassistant > ", end="", flush=True)

    @override
    def on_text_delta(self, delta, snapshot) -> None:
        """Print the newly received fragment of the reply text."""
        # A delta may carry no text value; skip it rather than printing
        # the literal string "None".
        if delta.value is not None:
            print(delta.value, end="", flush=True)

    @override
    def on_tool_call_created(self, tool_call) -> None:
        """Announce the type of tool the assistant has started calling."""
        print(f"\nassistant > {tool_call.type}\n", flush=True)

    @override
    def on_tool_call_delta(self, delta, snapshot) -> None:
        """Print code-interpreter input and log output as they stream in.

        Deltas of any other tool type are ignored.
        """
        if delta.type != "code_interpreter":
            return

        code_interpreter = delta.code_interpreter
        # The payload is optional on a delta; nothing to print without it.
        if code_interpreter is None:
            return

        if code_interpreter.input:
            print(code_interpreter.input, end="", flush=True)

        if code_interpreter.outputs:
            print("\n\noutput >", flush=True)
            for output in code_interpreter.outputs:
                # Only log outputs are printed; other output types are skipped.
                if output.type == "logs":
                    print(f"\n{output.logs}", flush=True)

    @override
    def on_end(self) -> None:
        """Finish the streamed output with a trailing newline."""
        print()
# Run the sample only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
requirements.txt