Capped parallel requests to OpenAI
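"""Capped parallel requests to OpenAI.

An asyncio.Semaphore bounds how many chat-completion requests are in
flight at once, and tenacity retries transient failures with jittered
exponential backoff.
"""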
import asyncio
import json
import logging
import os
from collections.abc import Awaitable

from openai import AsyncOpenAI
from tenacity import (
    after_log,
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Shared async client; credentials come from the environment.
client = AsyncOpenAI(
    api_key=os.environ["OPENAI_KEY"],
    organization=os.environ["OPENAI_ORG_ID"],
)
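# Retry each request up to 6 times, waiting a jittered exponential backoff
# (capped at 60 s) between attempts; retries are logged at INFO via after_log.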
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    after=after_log(logger, logging.INFO),
)
async def chat_completion(model: str, prompted_message: str):
    messages = [
        {
            "role": "user",
            "content": prompted_message,
        }
    ]
    return await client.chat.completions.create(
        model=model,
        messages=messages,
        seed=112892,
        temperature=0,
    )
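# Cap concurrency with a semaphore: a coroutine only starts once it acquires
# the semaphore, so at most max_concurrent_tasks requests run at a time.
# This relies on receiving bare coroutines, which do not execute until
# awaited, rather than already-scheduled asyncio.Tasks.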
async def submit_tasks(tasks: list[Awaitable], max_concurrent_tasks: int):
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    async def sem_task(task):
        async with semaphore:
            return await task

    return await asyncio.gather(*(sem_task(task) for task in tasks))
def parse_response(response: str):
    # Some post-processing logic
    ...
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        return []
def create_prompt(text: str) -> str:
    # Some prompt template (elided in the original); it should return the
    # string sent as the user message.
    ...
def main(inputs: list[str], model: str, max_concurrent_tasks: int = 5):
    tasks = [
        chat_completion(model=model, prompted_message=create_prompt(text))
        for text in inputs
    ]
    responses = asyncio.run(submit_tasks(tasks, max_concurrent_tasks))
    return [
        parse_response(response.choices[0].message.content) for response in responses
    ]
if __name__ == "__main__":
    inputs = ["input 1", "input 2", "input 3"]
    # The original call omitted the required `model` argument; the model name
    # here is an assumption, so substitute the one you use.
    generations = main(inputs, model="gpt-4-1106-preview", max_concurrent_tasks=2)
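# `generations` holds one entry per input: the JSON parsed from each model
# response, or [] when a response is not valid JSON.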