Skip to content

Instantly share code, notes, and snippets.

Created July 15, 2023 19:47
Show Gist options
  • Save monk1337/6feb740d4001f3d1e2e8470fcc702168 to your computer and use it in GitHub Desktop.
Save monk1337/6feb740d4001f3d1e2e8470fcc702168 to your computer and use it in GitHub Desktop.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Union, Dict
import tenacity
class Model(metaclass=ABCMeta):
Abstract base class for a large language model (LLM).
name : str
The name of the language model.
description : str
A brief description of the language model.
__init__(api_key, model, api_wait=None, api_retry=None, **kwargs) -> None:
Initializes the Model class with the required parameters and verifies the model is supported by the endpoint.
supported_models() -> List[str]:
Abstract method to return a list of supported models for the endpoint.
_verify_model() -> None:
Abstract method to verify if the model is supported by the endpoint.
set_key(api_key) -> None:
Abstract method to set the endpoint API key.
set_model(model) -> None:
Abstract method to set the model name for the endpoint.
get_description() -> str:
Abstract method to get the model description.
get_endpoint() -> str:
Abstract method to get the model endpoint.
get_parameters() -> Dict[str, str]:
Abstract method to get the model parameters.
run(prompts) -> List[str]:
Abstract method to run the language model on the given list of prompts and return the list of responses.
model_output(response) -> Any:
Abstract method to get the model output from the response.
_retry_decorator() -> tenacity.Retry:
Decorator function for retrying API requests if they fail.
execute_with_retry(*args, **kwargs) -> List[str]:
Decorated version of the `run` method with the retry logic.
>>> class MyModel(Model):
... def __init__(self, api_key, model, api_wait=None, api_retry=None, **kwargs):
... super().__init__(api_key, model, api_wait, api_retry, **kwargs)
... @classmethod
... def supported_models(cls) -> List[str]:
... return ['gpt', 'davinci']
... def _verify_model(self):
... assert self.model in self.supported_models(), f"{self.model} is not a supported model"
... def set_key(self, api_key: str):
... self.api_key = api_key
... def set_model(self, model: str):
... self.model = model
... def get_description(self) -> str:
... return self.description
... def get_endpoint(self) -> str:
... return ""
... def get_parameters(self) -> Dict[str, str]:
... return {"model": self.model, "prompt": "", "temperature": "0.7"}
... def run(self, prompts: List[str]) -> List[str]:
... # Send the request to OpenAI's API
... response =
... self.get_endpoint(),
... headers={
... "Content-Type": "application/json",
... "Authorization": f"Bearer {self.api_key}",
... },
... json=self.get_parameters(),
... )
... # Get the output from the response
... output = self.model_output(response)
... # Return the output
... return output
... def model_output(self, response):
... return response.json()['choices'][0]['text']
This class is an abstract base class for creating large language model (LLM) classes.
It provides common methods and attributes that concrete LLM classes can share,
so that calls to different LLM APIs are made in a uniform, streamlined way.
name = ""
description = ""
def __init__(
api_key: str,
model: str,
api_wait: int = 60,
api_retry: int = 6,
Initializes the Model class with the required parameters and verifies the model is supported by the endpoint.
api_key : str
The API key if needed for the endpoint.
model : str
The name of the LLM model to use for the endpoint.
api_wait : int, optional
Maximum wait time for an API request before retrying (in seconds), by default 60.
api_retry : int, optional
Number of times to retry an API request before failing, by default 6.
**kwargs : dict
Additional arguments to be passed to the Model API call.
This method initializes the Model class with the required parameters and verifies that the given model is supported by the endpoint. It sets the values of `api_key`, `model`, `api_wait`, and `api_retry` attributes of the class.
>>> my_model = MyModel(api_key="my_api_key", model="davinci")
self.api_key = api_key
self.model = model
self.api_wait = api_wait
self.api_retry = api_retry
def supported_models(self):
    """
    Get a list of supported models for the endpoint.

    Returns
    -------
    List[str]
        A list of supported models for the endpoint.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def supported_models(self):
    ...         return ['gpt', 'davinci']
    """
    raise NotImplementedError
def _verify_model(self):
    """
    Verify the model is supported by the endpoint.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def _verify_model(self):
    ...         assert self.model in self.supported_models(), f"{self.model} is not a supported model"
    """
    raise NotImplementedError
def set_key(self, api_key: str):
    """
    Set endpoint API key if needed.

    Parameters
    ----------
    api_key : str
        The API key to set.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def set_key(self, api_key: str):
    ...         self.api_key = api_key
    """
    raise NotImplementedError
def set_model(self, model: str):
    """
    Set model name for the endpoint.

    Parameters
    ----------
    model : str
        The model name to set.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def set_model(self, model: str):
    ...         self.model = model
    """
    raise NotImplementedError
def get_description(self) -> str:
    """
    Get model description.

    Returns
    -------
    str
        A string containing the model description.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def get_description(self) -> str:
    ...         return self.description
    """
    raise NotImplementedError
def get_endpoint(self) -> str:
    """
    Get model endpoint.

    Returns
    -------
    str
        A string containing the model endpoint.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def get_endpoint(self) -> str:
    ...         return ""
    """
    raise NotImplementedError
def get_parameters(self) -> Dict[str, str]:
    """
    Get model parameters.

    Returns
    -------
    Dict[str, str]
        A dictionary containing the model parameters.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def get_parameters(self) -> Dict[str, str]:
    ...         return {"model": self.model, "prompt": "", "temperature": "0.7"}
    """
    raise NotImplementedError
def run(self, prompts: List[str]) -> List[str]:
    """
    Run the LLM on the given prompt list.

    Parameters
    ----------
    prompts : List[str]
        A list of prompts to run on the LLM.

    Returns
    -------
    List[str]
        A list of responses from the LLM.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def run(self, prompts: List[str]) -> List[str]:
    ...         # Send the request to OpenAI's API
    ...         response = requests.post(
    ...             self.get_endpoint(),
    ...             headers={
    ...                 "Content-Type": "application/json",
    ...                 "Authorization": f"Bearer {self.api_key}",
    ...             },
    ...             json=self.get_parameters(),
    ...         )
    ...         # Get the output from the response
    ...         output = self.model_output(response)
    ...         # Return the output
    ...         return output
    """
    raise NotImplementedError
def model_output(self, response):
    """
    Get the model output from the response.

    Parameters
    ----------
    response : requests.Response
        The response from the API call.

    Notes
    -----
    This method is an abstract method and must be implemented in the derived classes.

    Examples
    --------
    >>> class MyModel(Model):
    ...     def model_output(self, response):
    ...         return response.json()['choices'][0]['text']
    """
    raise NotImplementedError
def _retry_decorator(self):
Decorator function for retrying API requests if they fail.
A decorator function for retrying API requests.
This method is a decorator function for retrying API requests using tenacity.
return tenacity.retry(
multiplier=0.3, exp_base=3, max=self.api_wait
def execute_with_retry(self, *args, **kwargs):
    """
    Decorated version of the run method with the retry logic.

    Parameters
    ----------
    *args : tuple
        A tuple of arguments to pass to the `run` method.
    **kwargs : dict
        A dictionary of keyword arguments to pass to the `run` method.

    Returns
    -------
    The output of the `run` method.

    Notes
    -----
    This method is a decorated version of the `run` method with the retry logic.
    """
    # Wrap the bound `run` at call time so the retry policy reflects the
    # instance's current settings.
    # NOTE(review): the wrapped callable was cut off in this copy of the file;
    # `self.run` is restored from the description above — confirm.
    decorated_run = self._retry_decorator()(self.run)
    return decorated_run(*args, **kwargs)
from typing import Dict, List, Optional, Tuple, Union
import openai
import json
import tiktoken
from promptify.parser.parser import Parser
from promptify.models.text2text.api.base_model import Model
class OpenAI(Model):
name = "OpenAI"
description = "OpenAI API for text completion using various models"
"completion_models": set(
"chat_models": set(
def __init__(
api_key: str,
model: str = "gpt-3.5-turbo",
temperature: float = 0.7,
top_p: float = 1,
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
presence_penalty: float = 0,
frequency_penalty: float = 0,
logit_bias: Optional[Dict[str, int]] = None,
request_timeout: Union[float, Tuple[float, float]] = None,
max_completion_length: int = 20,
super().__init__(api_key, model, api_wait, api_retry)
self.temperature = temperature
self.top_p = top_p
self.n = n
self.stop = stop
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias or {}
self.request_timeout = request_timeout
self.max_completion_length = max_completion_length
self.parameters = self.get_parameters()
def set_key(self, api_key: str):
    """
    Register the API key with the ``openai`` module and keep a handle to it.

    Parameters
    ----------
    api_key : str
        The key assigned to ``openai.api_key``.

    Notes
    -----
    ``self._openai`` is the imported ``openai`` module itself, so the key is
    set on the module, not on a per-instance client.
    """
    client = openai
    client.api_key = api_key
    self._openai = client
def _verify_model(self):
    """
    Classify ``self.model`` and reject it when it is not supported.

    The model is treated as a completion model when it appears in
    ``SUPPORTED_MODELS["completion_models"]``; otherwise it is assumed to be a
    chat model and must appear in ``SUPPORTED_MODELS["chat_models"]``.

    Raises
    ------
    ValueError
        If the model is in neither set.

    Notes
    -----
    On success the resolved category is stored in ``self.model_type``.
    """
    if self.model in self.SUPPORTED_MODELS["completion_models"]:
        model_type = "completion_models"
    else:
        model_type = "chat_models"

    # Only the chat fallback can actually fail this check.
    if self.model not in self.SUPPORTED_MODELS[model_type]:
        raise ValueError(f"Unsupported model: {self.model}")
    self.model_type = model_type
def _initialize_encoder(self):
    """Look up the tiktoken encoding for ``self.model`` and store it as ``self.encoder``."""
    encoding = tiktoken.encoding_for_model(self.model)
    self.encoder = encoding
def _initialize_parser(self):
    """Create a fresh ``Parser`` and store it as ``self.parser``."""
    parser = Parser()
    self.parser = parser
def set_model(self, model: str):
    """
    Store ``model`` as the active model name.

    Parameters
    ----------
    model : str
        The model name to assign to ``self.model``.

    Notes
    -----
    Only the attribute is reassigned; no validation is performed here.
    """
    setattr(self, "model", model)
def supported_models(self):
    """
    Return every supported model name as a flat list.

    Returns
    -------
    list
        The members of all groups in ``self.SUPPORTED_MODELS``, concatenated
        in the mapping's iteration order.
    """
    # Imported locally so the method does not depend on a module-level
    # `import itertools`, which is not among the visible imports of this file.
    from itertools import chain

    return list(chain.from_iterable(self.SUPPORTED_MODELS.values()))
def get_parameters(self):
    """
    Build the request-parameter dictionary from the instance's settings.

    Returns
    -------
    dict
        A new dictionary with the keys ``temperature``, ``top_p``, ``n``,
        ``stop``, ``presence_penalty``, ``frequency_penalty``, ``logit_bias``
        and ``request_timeout``, each taken from the attribute of the same
        name.
    """
    return {
        "temperature": self.temperature,
        "top_p": self.top_p,
        "n": self.n,
        "stop": self.stop,
        "presence_penalty": self.presence_penalty,
        "frequency_penalty": self.frequency_penalty,
        "logit_bias": self.logit_bias,
        "request_timeout": self.request_timeout,
    }
def get_description(self):
    """Return the ``description`` attribute of this model."""
    text = self.description
    return text
def get_endpoint(self):
    """
    Retrieve the model record via ``self._openai.Model.retrieve`` and return its id.

    Returns
    -------
    The value stored under the ``"id"`` key of the retrieved record.
    """
    return self._openai.Model.retrieve(self.model)["id"]
def run(self, prompt: str):
    """
    Send ``prompt`` to the API that matches ``self.model_type``.

    Parameters
    ----------
    prompt : str
        The prompt text to submit.

    Returns
    -------
    The result of ``_chat_api`` for ``"chat_models"``, the result of
    ``_completion_api`` for ``"completion_models"``, and ``None`` for any
    other model type.
    """
    if self.model_type == "chat_models":
        return self._chat_api(prompt)
    if self.model_type == "completion_models":
        return self._completion_api(prompt)
    return None
# def model_output(self, response):
# return self.model_output_formatted(response, self.max_completion_length)
def _completion_api(self, prompt: str):
self.parameters["prompt"] = prompt
self.parameters["max_tokens"] = self._calculate_max_tokens(prompt)
response = self._openai.Completion.create(
return response
def _chat_api(self, prompt: str):
prompt_template = [
{"role": "system", "content": "you are a helpful assistant."},
{"role": "user", "content": prompt},
self.parameters["max_tokens"] = self._calculate_max_tokens(prompt_template)
self.parameters["messages"] = prompt_template
response = self._openai.ChatCompletion.create(
return response
def _calculate_max_tokens(self, prompt: str) -> int:
    """
    Compute how many tokens remain for the completion after the prompt.

    Parameters
    ----------
    prompt : str
        The prompt to measure. It is passed through ``str()`` before encoding,
        so non-string values are counted by their string representation.

    Returns
    -------
    int
        ``_default_max_tokens(self.model)`` minus the number of tokens that
        ``self.encoder`` produces for the prompt. No lower bound is applied.
    """
    encoded_prompt = self.encoder.encode(str(prompt))
    context_limit = self._default_max_tokens(self.model)
    return context_limit - len(encoded_prompt)
def _default_max_tokens(self, model_name: str) -> int:
    """
    Return the token limit recorded for ``model_name``.

    Parameters
    ----------
    model_name : str
        Exact model identifier to look up.

    Returns
    -------
    int
        The limit listed for that model in the table below.

    Raises
    ------
    KeyError
        If ``model_name`` is not in the table.
    """
    token_dict = {
        # NOTE(review): 2040 (not 2048) is kept exactly as in the original
        # table — confirm it is intentional.
        "text-babbage-001": 2040,
        "text-ada-001": 2048,
        "ada": 2048,
        "babbage": 2048,
        "text-curie-001": 2048,
        "curie": 2048,
        "davinci": 2048,
        "code-cushman-002": 2048,
        "code-cushman-001": 2048,
        "text-davinci-003": 4000,
        "gpt-3.5-turbo": 4096,
        "gpt-3.5-turbo-0301": 4096,
        "gpt-3.5-turbo-0613": 4096,
        "text-davinci-002": 4096,
        "code-davinci-002": 8000,
        "code-davinci-001": 8000,
        "gpt-4": 8192,
        "gpt-4-0314": 8192,
        "gpt-4-0613": 8192,
        "gpt-3.5-turbo-16k": 16385,
        "gpt-3.5-turbo-16k-0613": 16385,
        "gpt-4-32k": 32768,
        "gpt-4-32k-0314": 32768,
        "gpt-4-32k-0613": 32768,
    }
    return token_dict[model_name]
def model_output_raw(self, response: Dict) -> Dict:
    """
    Pull the generated text and usage statistics out of an API response.

    Parameters
    ----------
    response : Dict
        Mapping with a ``"choices"`` list and a ``"usage"`` entry.

    Returns
    -------
    Dict
        ``"text"`` holds the first choice's content (stripped of surrounding
        spaces and newlines for chat models, unmodified for completion
        models) and ``"usage"`` holds a ``dict`` copy of ``response["usage"]``.
        ``"text"`` is absent when ``self.model_type`` is neither known type.
    """
    result = {}
    kind = self.model_type
    if kind == "chat_models":
        message = response["choices"][0]["message"]
        result["text"] = message["content"].strip(" \n")
    elif kind == "completion_models":
        first_choice = response["choices"][0]
        result["text"] = first_choice["text"]
    result["usage"] = dict(response["usage"])
    return result
def model_output(self, response, max_completion_length: int) -> Dict:
    """
    Extract the raw output from ``response`` and attach a parsed version.

    Parameters
    ----------
    response
        The API response, forwarded unchanged to ``model_output_raw``.
    max_completion_length : int
        Forwarded to the parser together with the generated text.

    Returns
    -------
    Dict
        The dictionary from ``model_output_raw`` with an extra ``"parsed"``
        entry holding the parser's result.
    """
    data = self.model_output_raw(response)
    # NOTE(review): the call target was garbled in this copy of the file;
    # `self.parser.fit` is reconstructed from the `self.parser` attribute set
    # by `_initialize_parser` — confirm against the Parser API.
    data["parsed"] = self.parser.fit(data["text"], max_completion_length)
    return data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment