Last active
July 20, 2023 06:08
-
-
Save eljojo/497fce6d11269578a330f8304919b820 to your computer and use it in GitHub Desktop.
Simple OpenAI API Client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # Base class for a function the model may call through OpenAI function calling.
  #
  # Subclasses must implement:
  # * +name+        - identifier sent to the API and matched against the model's reply
  # * +description+ - human-readable summary sent to the API
  # * +parameters+  - JSON-Schema-style hash describing the expected arguments
  # * +parse_answer(**arguments)+ - receives the arguments the model supplied
  class Function
    # Returns a formatted function payload (one entry of the request's "functions" list).
    #
    # @return [Hash{String => Object}]
    def to_payload
      {
        "name" => name,
        "description" => description,
        "parameters" => parameters,
      }
    end

    # @return [String] the function's identifier.
    # @raise [NotImplementedError] unless overridden by a subclass.
    def name
      raise NotImplementedError, "#{self.class} must implement #name"
    end

    # @return [String] what the function does, as shown to the model.
    # @raise [NotImplementedError] unless overridden by a subclass.
    def description
      raise NotImplementedError, "#{self.class} must implement #description"
    end

    # @return [Hash] JSON-Schema-style description of the arguments.
    # @raise [NotImplementedError] unless overridden by a subclass.
    def parameters
      raise NotImplementedError, "#{self.class} must implement #parameters"
    end

    # Handles the arguments chosen by the model.
    #
    # @param _arguments [Hash{Symbol => Object}] decoded keyword arguments.
    # @raise [NotImplementedError] unless overridden by a subclass.
    def parse_answer(**_arguments)
      raise NotImplementedError, "#{self.class} must implement #parse_answer"
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # Interprets a chat-completion response in which the model asked to call one
  # of a known set of functions, and hands the decoded arguments to that
  # function's +parse_answer+.
  class FunctionAnswer
    attr_reader :functions

    # @param functions [Object, Array<Object>] candidate functions; each must
    #   respond to +name+ and +parse_answer(**arguments)+.
    def initialize(functions)
      @functions = Array(functions)
    end

    # @param api_response [Hash] parsed JSON body of a chat/completions call.
    # @return [Object] whatever the matching function's +parse_answer+ returns.
    # @raise [StandardError] if the response carries no function call, names an
    #   unknown function, or its arguments are not a JSON object.
    # @raise [JSON::ParserError] if the arguments are not valid JSON.
    def process(api_response)
      # dig tolerates an empty "choices" list or a null "function_call".
      function_call = api_response.dig("choices", 0, "message", "function_call") || {}
      function_name = function_call["name"]
      function_arguments = function_call["arguments"]

      unless function_name && function_arguments
        raise StandardError, "No function call in the response"
      end

      # Find the function whose name matches the one the model chose.
      matching_function = functions.find { |f| f.name == function_name }
      unless matching_function
        raise StandardError, "Unexpected function name in the response: #{function_name}"
      end

      arguments = JSON.parse(function_arguments)
      unless arguments.is_a?(Hash)
        raise StandardError, "Function arguments are not a JSON object: #{function_arguments}"
      end

      # Only top-level keys are symbolized (so they bind to keyword parameters);
      # nested objects keep their string keys.
      matching_function.parse_answer(**arguments.transform_keys(&:to_sym))
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  module Functions
    # Example function definition: lets the model ask for the current weather
    # at a location. parse_answer simply prints the arguments it receives.
    class Weather < Function
      def name
        "get_current_weather"
      end

      def description
        "Get the current weather in a given location"
      end

      # JSON-Schema-style description of the arguments the model must supply.
      # A fresh hash is built on every call.
      def parameters
        location_schema = {
          type: "string",
          description: "The city and state, e.g. San Francisco, CA",
        }
        unit_schema = { type: "string", enum: %w[celsius fahrenheit] }

        {
          type: "object",
          properties: { location: location_schema, unit: unit_schema },
          required: %w[location unit],
        }
      end

      # Prints "<location>: <unit>" to stdout and returns nil (puts' result).
      def parse_answer(location:, unit:)
        puts "#{location}: #{unit}"
      end
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # Minimal HTTP client for the OpenAI chat completions endpoint.
  class OpenAiService
    BASE_URL = "https://api.openai.com/v1"

    # Seconds Net::HTTP waits for a read (previously hard-coded as 2.minutes).
    DEFAULT_READ_TIMEOUT = 120

    # @param read_timeout [Numeric] seconds to wait for each socket read.
    def initialize(read_timeout: DEFAULT_READ_TIMEOUT)
      @read_timeout = read_timeout
    end

    # Calls the OpenAI chat-based model.
    #
    # @param payload [#to_json] request body, e.g. a prompt's to_payload hash.
    # @return [Hash] the parsed JSON response body.
    # @raise [StandardError] on any non-2xx response (status code and body included).
    def process(payload)
      uri = URI("#{BASE_URL}/chat/completions")
      request = Net::HTTP::Post.new(uri.request_uri, request_headers)
      request.body = payload.to_json

      # The block form of start guarantees the connection is closed afterwards.
      response = Net::HTTP.start(
        uri.host,
        uri.port,
        use_ssl: uri.scheme == "https",
        read_timeout: @read_timeout,
      ) { |http| http.request(request) }

      return JSON.parse(response.body) if response.is_a?(Net::HTTPSuccess)

      raise StandardError, "Failed to process text. Error (HTTP #{response.code}): #{response.body}"
    end

    # @return [String] the API key read from Rails credentials (openai.api_key).
    # @raise [RuntimeError] if no key is configured.
    def api_key
      Rails.application.credentials.openai&.api_key || raise("missing OpenAI API Key")
    end

    private

    # Headers for a JSON request authenticated with the bearer API key.
    def request_headers
      {
        "Content-Type" => "application/json",
        "Authorization" => "Bearer #{api_key}",
      }
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # Couples a prompt with the parser for its reply. Running it sends the prompt
  # through an OpenAiService and feeds the raw response to the answer object.
  class PromptAndAnswer
    attr_reader :prompt, :expected_answer

    # @param prompt [#process] object that performs the request given a service.
    # @param expected_answer [#process] object that interprets the raw response.
    def initialize(prompt, expected_answer)
      @prompt = prompt
      @expected_answer = expected_answer
    end

    # @param open_ai_service [OpenAiService] client used to perform the request;
    #   a new one is built only when none is passed.
    # @return [Object] whatever expected_answer.process returns.
    def process(open_ai_service = OpenAiService.new)
      prompt.process(open_ai_service).then do |raw_response|
        expected_answer.process(raw_response)
      end
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # Catalogue of ready-made prompts. Each method returns the result of
  # SimplePrompt#with_answer, pairing a prompt with the parser for its reply.
  module Prompts
    extend self

    SMART_WRITER = "You are a very smart writer."

    # Asks the model to polish +text+; the whole reply is returned as the answer.
    def rewrite_text(text)
      instructions = "#{SMART_WRITER} Please improve the following blog post, maintaining the original tone of the author."
      SimplePrompt.new(instructions, text).with_answer(all_text: true)
    end

    # Asks for a two-part joke; the answer is the match of everything after the first colon.
    def tell_joke
      joke_request = "tell me a joke with two parts split by a colon"
      general_request(joke_request).with_answer(regex: /:(?<joke>.+)/)
    end

    # Demo of function calling: offers the Weather function for a weather question.
    def kitchen
      weather = Functions::Weather.new
      general_request("What's the weather in ottawa?", functions: [weather]).with_answer
    end

    private

    # Builds a SimplePrompt with the generic "answer the request" system text.
    def general_request(user_text, **prompt_opts)
      SimplePrompt.new("#{SMART_WRITER} Answer the user's request", user_text, **prompt_opts)
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # A two-message (system + user) chat prompt for the OpenAI chat completions
  # endpoint, optionally advertising callable functions to the model.
  class SimplePrompt
    DEFAULT_MODEL = "gpt-3.5-turbo-16k-0613"

    attr_accessor :system_text, :user_text, :max_tokens, :functions, :temperature, :model

    # @param system_text [String] instructions sent with the "system" role.
    # @param user_text [String] the request sent with the "user" role.
    # @param max_tokens [Integer] completion token cap sent as "max_tokens".
    # @param temperature [Numeric] how creative: 0 is none, 1 is very.
    # @param functions [Object, Array<Object>] function definitions responding
    #   to +to_payload+; stored as nil when none are given.
    # @param model [String] model identifier sent as "model".
    def initialize(system_text, user_text, max_tokens: 1024, temperature: 0.5, functions: [], model: DEFAULT_MODEL)
      @system_text = system_text
      @user_text = user_text
      @max_tokens = max_tokens
      @temperature = temperature
      @model = model
      function_list = Array(functions)
      @functions = function_list.empty? ? nil : function_list
    end

    # Sends this prompt and returns the parsed API response.
    #
    # @param open_ai_service [OpenAiService] client that performs the request.
    # @return [Hash]
    def process(open_ai_service = OpenAiService.new)
      open_ai_service.process(to_payload)
    end

    # Returns a formatted message payload.
    #
    # @return [Hash{String => Object}] request body; nil-valued keys are dropped.
    def to_payload
      function_payloads = Array(functions).map(&:to_payload)
      {
        "model" => model,
        "messages" => messages,
        "temperature" => temperature,
        "max_tokens" => max_tokens,
        # Omit the key when there are no functions (including when the
        # accessor was set to []); the API is understood to reject an empty list.
        "functions" => (function_payloads unless function_payloads.empty?),
      }.compact
    end

    # @return [Array<Hash{String => String}>] the system and user messages.
    def messages
      [
        {
          "role" => "system",
          "content" => @system_text,
        },
        {
          "role" => "user",
          "content" => @user_text,
        },
      ]
    end

    # Pairs this prompt with a parser for its reply.
    #
    # @param text_opts [Hash] options for TextAnswer (regex:, all_text:); when
    #   given they take precedence over functions.
    # @return [PromptAndAnswer]
    # @raise [RuntimeError] if neither text options nor functions are present.
    def with_answer(**text_opts)
      return PromptAndAnswer.new(self, TextAnswer.new(**text_opts)) unless text_opts.empty?
      return PromptAndAnswer.new(self, FunctionAnswer.new(functions)) unless Array(functions).empty?

      raise "no text_opts or functions specified for prompt #{self}"
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

module Ai
  # Extracts the model's plain-text reply from a chat-completion response,
  # either verbatim (all_text: true) or through a regular expression.
  class TextAnswer
    # @param regex [Regexp, nil] pattern applied to the reply; its MatchData is returned.
    # @param all_text [Boolean, nil] when truthy, return the whole reply (wins over regex).
    # @raise [RuntimeError] if neither option is given.
    def initialize(regex: nil, all_text: nil)
      @regex = regex
      @all_text = all_text
      raise("unsure how to parse answer! please set regex: or all_text: true") unless regex || all_text
    end

    # @param api_response [Hash] parsed JSON body of a chat/completions call.
    # @return [String, MatchData] the full reply, or the regex match.
    # @raise [StandardError] if the reply is missing or blank, or the regex does not match.
    def process(api_response)
      # dig tolerates an empty "choices" list instead of raising NoMethodError.
      content = api_response.dig("choices", 0, "message", "content")
      raise StandardError, "No content in the response" if blank_content?(content)
      return content if @all_text

      content.match(@regex) || raise(StandardError, "No match found in the response: #{content}")
    end

    private

    # True for nil or whitespace-only text (same test as ActiveSupport's blank?).
    # Named to avoid shadowing ActiveSupport's Object#blank?.
    def blank_content?(value)
      value.nil? || value.to_s.match?(/\A[[:space:]]*\z/)
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You can use this like so: