Created
June 11, 2024 08:43
-
-
Save habedi/92522e891ada01aa9da4be86c8028a8a to your computer and use it in GitHub Desktop.
Example OpenAI connector code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any | |
import numpy as np | |
import openai | |
from numpy import ndarray, dtype | |
class LanguageModels:
    """Namespace of string constants naming the supported chat language models.

    The members are plain ``str`` values rather than ``enum.Enum`` members,
    so they can be passed straight through as the ``model`` argument of a
    chat completion request.
    """

    # Legacy, low-cost chat model.
    OPENAI_GPT35TURBO = "gpt-3.5-turbo"
    # GPT-4o chat model.
    OPENAI_GPT4O = "gpt-4o"
class EmbeddingModels:
    """Namespace of string constants naming the supported text embedding models.

    The members are plain ``str`` values rather than ``enum.Enum`` members,
    so they can be passed straight through as the ``model`` argument of an
    embeddings request.
    """

    # Original (misspelled ``OPENAPI_``) names, kept for backward compatibility.
    OPENAPI_EMS = 'text-embedding-3-small'
    OPENAPI_EML = 'text-embedding-3-large'

    # Correctly spelled aliases, consistent with the ``OPENAI_`` prefix used
    # by ``LanguageModels``. They refer to the same model identifiers.
    OPENAI_EMS = OPENAPI_EMS
    OPENAI_EML = OPENAPI_EML
class OpenAIConnector:
    """Thin wrapper around an OpenAI client for text embeddings and chat completions."""

    # Maximum number of inputs sent in a single embeddings request. The OpenAI
    # API reference states that an input array must have 2048 entries or fewer.
    MAX_EMBEDDING_BATCH_SIZE = 2048

    def __init__(self, api_key: str):
        """
        Initialize the OpenAIConnector with the provided API key.

        Args:
            api_key (str): The API key for authenticating with OpenAI.
        """
        self.client = openai.Client(api_key=api_key)

    def embed(self, documents: list[str],
              embedding_model: str = EmbeddingModels.OPENAPI_EMS,
              batch_size: int = MAX_EMBEDDING_BATCH_SIZE) -> list[ndarray[Any, dtype[Any]]]:
        """
        Generate embeddings for a list of documents using the specified embedding model.

        Documents are sent in batches of at most ``batch_size`` per request, and
        the returned embeddings are in the same order as ``documents``.

        Args:
            documents (list[str]): A list of documents to be embedded. A single
                string is also accepted and is treated as a one-element list.
            embedding_model (str): The embedding model to use (default is
                'text-embedding-3-small').
            batch_size (int): Maximum number of documents per API request
                (default is 2048, the API's per-request input limit).

        Returns:
            list[ndarray[Any, dtype[Any]]]: One numpy array per document, holding
            its embedding. An empty input yields an empty list without calling
            the API.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        if isinstance(documents, str):
            # The API also accepts a bare string. Wrap it so the slicing below
            # batches whole documents instead of individual characters.
            documents = [documents]

        vectors: list[ndarray[Any, dtype[Any]]] = []
        # An empty `documents` skips the loop entirely, so no (invalid) empty
        # request is sent to the API.
        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            response = self.client.embeddings.create(input=batch, model=embedding_model)
            # Each result carries the position of its input within the batch.
            # Sort on it so output order always matches input order.
            for item in sorted(response.data, key=lambda d: d.index):
                vectors.append(np.array(item.embedding))
        return vectors

    def chat(self, prompt: str, model: str = LanguageModels.OPENAI_GPT35TURBO,
             temperature: float = 0.0, max_tokens: int = 100) -> str:
        """
        Generate a chat completion for the given prompt using the specified language model.

        The prompt is sent as a single user message, and the text of the first
        returned choice is returned.

        Args:
            prompt (str): The prompt to send to the language model.
            model (str): The language model to use (default is 'gpt-3.5-turbo').
            temperature (float): The sampling temperature (default is 0.0).
            max_tokens (int): The maximum number of tokens to generate (default is 100).

        Returns:
            str: The generated completion text.
        """
        completion = self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "user", "content": prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # NOTE(review): the OpenAI SDK types `message.content` as optional
        # (e.g. for tool-call or refusal responses). This may therefore return
        # None despite the `-> str` annotation — confirm how callers handle it.
        return completion.choices[0].message.content
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment