Skip to content

Instantly share code, notes, and snippets.

@mahiya
Created January 31, 2024 00:40
Show Gist options
  • Save mahiya/08315cffb4b21f602a0d86698fef111e to your computer and use it in GitHub Desktop.
Save mahiya/08315cffb4b21f602a0d86698fef111e to your computer and use it in GitHub Desktop.
マルチスレッドで Azure OpenAI Service の Embeddings API を呼び出して埋め込みを取得する処理 (Python)
# Example: embed a batch of texts in parallel with the multi-resource client.
from utilities.parallel_embeddings import ParallelEmbeddingsClient

client = ParallelEmbeddingsClient("embeddings_config.json")
# NOTE(review): original used an f-string with no placeholders (ruff F541) and
# an unused loop variable; a plain literal and "_" are behaviorally identical.
texts = ["埋め込みを取得する対象のテキスト" for _ in range(10)]
embeds = client.get_embeds(texts)
[
{
"name": "",
"key": "",
"deployName": "text-embedding-ada-002"
},
{
"name": "",
"key": "",
"deployName": "text-embedding-ada-002"
}
]
import json
from openai import AzureOpenAI
from concurrent.futures.thread import ThreadPoolExecutor
# Number of worker threads spawned per Azure OpenAI client — each resource's
# share of the texts is fanned out over this many concurrent API calls.
threads_per_client = 4
# Default model/deployment name passed to the Embeddings API.
embeddings_model = "text-embedding-ada-002"
class ParallelEmbeddingsClient:
    """Fetch embeddings from multiple Azure OpenAI resources in parallel.

    The input texts are split round-robin across every configured resource,
    and each resource's share is further split across ``threads_per_client``
    worker threads, so many Embeddings API calls run concurrently.
    """

    def __init__(self, config_file_path="embeddings_config.json"):
        """Create one AzureOpenAI client per entry of the JSON config file.

        Each config entry must provide "name" (the Azure resource name) and
        "key"; "api_version" and "deployName" are optional overrides.
        """
        # utf-8 explicitly: the config may contain non-ASCII text and the
        # platform default encoding is not guaranteed.
        with open(config_file_path, "r", encoding="utf-8") as f:
            configs = json.load(f)
        self.clients = [
            AzureOpenAI(
                azure_endpoint=f"https://{config['name']}.openai.azure.com/",
                api_key=config["key"],
                api_version=config.get("api_version", "2023-08-01-preview"),
            )
            for config in configs
        ]
        # Per-client deployment name. Previously the config's "deployName"
        # was read but ignored and the module default was always used;
        # defaulting to embeddings_model keeps old configs working unchanged.
        self._deploy_names = [
            config.get("deployName", embeddings_model) for config in configs
        ]

    def __chunk(self, arr, size):
        """Split *arr* round-robin into *size* sublists (some may be empty)."""
        chunks = [[] for _ in range(size)]
        for i, item in enumerate(arr):
            chunks[i % size].append(item)
        return chunks

    def __get_embeds(self, client, model, texts):
        """Sequentially embed each {"order", "text"} item with one client."""
        embeds = []
        for text in texts:
            resp = client.embeddings.create(model=model, input=text["text"])
            embeds.append({"order": text["order"], "embed": resp.data[0].embedding})
        return embeds

    def __start_get_embeds_thread(self, client, model, texts):
        """Fan one client's share of texts out over threads_per_client threads."""
        sub_chunks = self.__chunk(texts, threads_per_client)
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.__get_embeds, client, model, sub_chunk)
                for sub_chunk in sub_chunks
            ]
            return [future.result() for future in futures]

    def get_embeds(self, texts):
        """Return embeddings for *texts*, in the same order as the input.

        Each text is tagged with its original index, distributed across all
        clients (and their worker threads), then reassembled in input order.
        """
        if not texts:
            return []
        texts = [{"order": i, "text": text} for i, text in enumerate(texts)]
        chunks = self.__chunk(texts, len(self.clients))
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.__start_get_embeds_thread, client, name, chunk)
                for client, name, chunk in zip(self.clients, self._deploy_names, chunks)
            ]
            results = [future.result() for future in futures]
        # Flatten [client][thread][item] -> [item], then restore input order.
        embeds = sum(sum(results, []), [])
        embeds.sort(key=lambda x: x["order"])
        return [embed["embed"] for embed in embeds]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment