Created
March 6, 2023 04:35
-
-
Save oaustegard/96d2029b7a598eb80b411801ada43080 to your computer and use it in GitHub Desktop.
ChatBot.py -- an amended version of Simon Willison's wrapper class for easily implementing the new ChatGPT API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simon Willison's ChatGPT API Wrapper Class: https://til.simonwillison.net/gpt3/chatgpt-api
# Amended to allow specification of temperature, top_p, n, stop, max_tokens, presence_penalty and frequency_penalty.
# Expects the API key to be in the OPENAI_API_KEY environment variable.
import openai | |
class ChatBot:
    """Stateful wrapper around the OpenAI chat completions endpoint.

    Every call appends the user message and the assistant reply to
    ``self.messages``, so the whole conversation is resent on the next call.

    NOTE: this uses the pre-1.0 ``openai`` SDK interface
    (``openai.ChatCompletion.create``); it will not work unchanged with
    ``openai>=1.0``.
    """

    def __init__(self, system="",
                 temperature=0.5, top_p=1, n=1, stop=None, max_tokens=None,
                 presence_penalty=0, frequency_penalty=0.5,
                 model="gpt-3.5-turbo"):
        """
        Create a new ChatBot instance.

        Args:
            system (str, optional): The system prompt. Defaults to "".
            temperature (float, optional): "Randomness" value between 0 and 2. Defaults to 0.5.
            top_p (float, optional): Nucleus sampling value from 0 to 1 -- if adjusting this,
                set temperature to 1. Defaults to 1.
            n (int, optional): How many choices to generate. If > 1, all choices are returned
                as one labelled string. Defaults to 1.
            stop (str | list[str] | None, optional): Up to 4 sequences where the API will stop
                generating further tokens. Defaults to None (no stop sequences).
            max_tokens (int | None, optional): The maximum number of tokens allowed for the
                generated answer. Prompt tokens plus max_tokens must fit in the model's context
                window, so this must leave room for the prompt. Defaults to None, which omits
                the parameter and lets the API use whatever context remains after the prompt.
            presence_penalty (float, optional): Number between -2.0 and 2.0. Positive values
                penalize new tokens based on whether they appear in the text so far, increasing
                the model's likelihood to talk about new topics. Defaults to 0.
            frequency_penalty (float, optional): Number between -2.0 and 2.0. Positive values
                penalize new tokens based on their existing frequency in the text so far,
                decreasing the model's likelihood to repeat the same line verbatim.
                Defaults to 0.5.
            model (str, optional): Chat model name sent to the API. Defaults to "gpt-3.5-turbo".
        """
        self.system = system
        self.messages = []
        if self.system:
            self.messages.append({"role": "system", "content": system})
        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.stop = stop or []
        self.max_tokens = max_tokens
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.model = model

    def __call__(self, message, print_usage=False):
        """
        Send a message to the chatbot and return the response.

        Args:
            message (str): The message to send to the chatbot.
            print_usage (bool, optional): Whether to print the token usage. Defaults to False.

        Returns:
            str: The response(s) from the chatbot.

        Raises:
            Any exception raised by the API call is re-raised after the user
            message has been removed from the history again.
        """
        self.messages.append({"role": "user", "content": message})
        try:
            result = self.execute(print_usage=print_usage)
        except Exception:
            # Roll back so a failed request does not leave an unanswered user
            # turn in the history, which would otherwise be resent next call.
            self.messages.pop()
            raise
        # NOTE(review): with n > 1 the labelled, joined string is stored as the
        # assistant turn, so the model later sees every choice as its own reply.
        self.messages.append({"role": "assistant", "content": result})
        return result

    def execute(self, print_usage=False):
        """
        Execute the chatbot on the current message history.

        Args:
            print_usage (bool, optional): Whether to print the usage. Defaults to False.

        Returns:
            str: The response(s) from the chatbot.
        """
        completion = openai.ChatCompletion.create(**self._build_request())
        if print_usage:
            print(completion.usage)
        return self._format_reply(completion.choices)

    def _build_request(self):
        """Return the keyword arguments passed to ``openai.ChatCompletion.create``."""
        params = {
            "model": self.model,
            "messages": self.messages,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": self.n,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
        }
        # Optional settings are sent only when actually set: an empty stop list
        # is not a meaningful value, and an unset max_tokens should fall back to
        # the API default rather than a fixed size that may not fit the prompt.
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    def _format_reply(self, choices):
        """Collapse the API's choices into the single string returned to callers."""
        # if n = 1 then simply return the first choice
        if self.n == 1:
            return choices[0].message.content
        # if n > 1 then return the choices as a \n\n separated string with a label for each
        return "\n\n".join(
            f"Choice {i}:\n{choice.message.content}"
            for i, choice in enumerate(choices, start=1)
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment