Created
January 31, 2024 00:40
-
-
Save mahiya/08315cffb4b21f602a0d86698fef111e to your computer and use it in GitHub Desktop.
マルチスレッドで Azure OpenAI Service の Embeddings API を呼び出して埋め込みを取得する処理 (Python)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from utilities.parallel_embeddings import ParallelEmbeddingsClient

# Build a client pool from the JSON config and embed 10 texts in parallel.
client = ParallelEmbeddingsClient("embeddings_config.json")
# Fix: the original used an f-string with no placeholders (ruff F541) and an
# unused loop variable; repeating a constant list element says the same thing.
texts = ["埋め込みを取得する対象のテキスト"] * 10
embeds = client.get_embeds(texts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[
    {
        "name": "",
        "key": "",
        "deployName": "text-embedding-ada-002"
    },
    {
        "name": "",
        "key": "",
        "deployName": "text-embedding-ada-002"
    }
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from openai import AzureOpenAI | |
from concurrent.futures.thread import ThreadPoolExecutor | |
# Number of worker threads each Azure OpenAI client fans its share out to.
threads_per_client = 4
# Fallback deployment name, used when a config entry omits "deployName".
embeddings_model = "text-embedding-ada-002"


class ParallelEmbeddingsClient:
    """Fetch embeddings in parallel across multiple Azure OpenAI resources.

    Input texts are distributed round-robin over every configured resource,
    and each resource's share is further split across ``threads_per_client``
    threads. Results are re-assembled in the original input order.
    """

    def __init__(self, config_file_path="embeddings_config.json"):
        """Create one AzureOpenAI client per entry in the JSON config file.

        Each entry provides "name" (resource name) and "key", and may provide
        "deployName" and "api_version".
        """
        # JSON is UTF-8 by spec; be explicit so a non-UTF-8 locale can't break it.
        with open(config_file_path, "r", encoding="utf-8") as f:
            configs = json.load(f)
        self.clients = [
            AzureOpenAI(
                azure_endpoint=f"https://{config['name']}.openai.azure.com/",
                api_key=config["key"],
                api_version=config.get("api_version", "2023-08-01-preview"),
            )
            for config in configs
        ]
        # Fix: honor the per-resource "deployName" from the config file. The
        # original code ignored it and always used the module-level default.
        self.deploy_names = [
            config.get("deployName", embeddings_model) for config in configs
        ]

    def __chunk(self, arr, size):
        """Split *arr* into *size* round-robin chunks (item i -> chunk i % size)."""
        chunks = [[] for _ in range(size)]
        for i, item in enumerate(arr):
            chunks[i % size].append(item)
        return chunks

    def __get_embeds(self, client, model, texts):
        """Sequentially embed each {"order", "text"} dict using *client*/*model*.

        Returns a list of {"order", "embed"} dicts.
        """
        embeds = []
        for text in texts:
            resp = client.embeddings.create(model=model, input=text["text"])
            embeds.append({"order": text["order"], "embed": resp.data[0].embedding})
        return embeds

    def __start_get_embeds_thread(self, client, model, texts):
        """Fan one client's share of texts out over threads_per_client threads."""
        sub_chunks = self.__chunk(texts, threads_per_client)
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.__get_embeds, client, model, sub_chunk)
                for sub_chunk in sub_chunks
            ]
            return [future.result() for future in futures]

    def get_embeds(self, texts):
        """Return one embedding vector per input text, in input order."""
        # Tag each text with its position so results can be re-sorted at the end.
        texts = [{"order": i, "text": text} for i, text in enumerate(texts)]
        chunks = self.__chunk(texts, len(self.clients))
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(
                    self.__start_get_embeds_thread,
                    self.clients[i],
                    self.deploy_names[i],
                    chunks[i],
                )
                for i in range(len(chunks))
            ]
            results = [future.result() for future in futures]
        # Flatten the two nesting levels (per-client, then per-thread).
        embeds = sum(sum(results, []), [])
        embeds.sort(key=lambda x: x["order"])
        return [embed["embed"] for embed in embeds]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment