Last active: March 8, 2023 08:52
-
-
Save josephspurrier/9530feb5b3c66e69850ef56ca8f5978a to your computer and use it in GitHub Desktop.
OpenAI/ChatGPT Terminal in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"fmt" | |
"log" | |
"os" | |
"strings" | |
) | |
// API reference: https://platform.openai.com/docs/api-reference/completions

// modelSelector chooses which chat model model() returns: true selects
// "gpt-3.5-turbo-0301", false selects "gpt-3.5-turbo". It is flipped by the
// "switchmodel" command and when the API reports that the current model is
// overloaded.
var modelSelector = true
func main() { | |
token := os.Getenv("OPENAI_TOKEN") | |
if len(token) == 0 { | |
log.Fatalln("could not load env var: OPENAI_TOKEN (generate one: https://platform.openai.com/account/api-keys)") | |
} | |
rc := NewRequestClient("https://api.openai.com", token) | |
rc.AddHeader("Content-Type", "application/json") | |
prompt := "\n$ " | |
fmt.Println("== OpenAI Terminal ==") | |
fmt.Println("# By default, typed text will be sent to the chat API (ChatGPT) unless an internal command is specified.") | |
fmt.Println("# Internal commands: exit, switchmodel, getmodel, createimage <description>") | |
fmt.Println("# Press 'Enter' twice when done typing because this app supports copy/paste.") | |
fmt.Println("# When you see '>', it's waiting for the response from OpenAI.") | |
for { | |
fmt.Print(prompt) | |
reader := bufio.NewReader(os.Stdin) | |
var input string | |
lastchar := ' ' | |
for { | |
char, _, err := reader.ReadRune() | |
if err != nil { | |
fmt.Println("Error reading input:", err) | |
return | |
} | |
if char == '\n' && lastchar == '\n' { | |
break | |
} | |
lastchar = char | |
input += string(char) | |
} | |
cleanInput := strings.TrimSuffix(input, "\n") | |
var rsp string | |
var err error | |
SendMsg: | |
switch true { | |
case cleanInput == "exit": | |
fmt.Printf("# Goodbye!\n") | |
os.Exit(0) | |
case cleanInput == "switchmodel": | |
modelSelector = !modelSelector | |
fmt.Printf("# Changed to model: %s\n", model()) | |
continue | |
case cleanInput == "getmodel": | |
fmt.Printf("# Current model: %s\n", model()) | |
continue | |
case strings.HasPrefix(cleanInput, "createimage"): | |
fmt.Print("> ") | |
rsp, err = sendImageRequest(rc, cleanInput[12:]) | |
default: | |
// By default, use chat mode. | |
fmt.Print("> ") | |
rsp, err = sendChatRequest(rc, cleanInput) | |
} | |
// Handle any errors. | |
if err != nil { | |
fmt.Println("Error:", err.Error()) | |
// If it's busy, swap the model. | |
if strings.Contains(err.Error(), "That model is currently overloaded with other requests") { | |
modelSelector = !modelSelector | |
fmt.Printf("# Retrying with model: %s...\n", model()) | |
goto SendMsg | |
} else { | |
continue | |
} | |
} | |
fmt.Println(rsp) | |
} | |
} | |
// ChatCompletionRequest is the JSON body that sendChatRequest POSTs to
// /v1/chat/completions.
type ChatCompletionRequest struct {
	// Model is the model name; sendChatRequest fills it from model().
	Model string `json:"model"`
	// Messages is the list of chat messages sent to the API. sendChatRequest
	// sends exactly one, with role "user".
	Messages []ChatCompletionRequestMessages `json:"messages"`
}
// ChatCompletionRequestMessages is a single chat message in a
// ChatCompletionRequest. Despite the plural name, each value holds one message.
type ChatCompletionRequestMessages struct {
	// Role is the message author's role; sendChatRequest uses "user".
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
// ChatCompletionResponse is the decoded JSON response of
// /v1/chat/completions. sendChatRequest reads only Choices; the other fields
// are decoded but not used by the code shown here.
type ChatCompletionResponse struct {
	ID     string `json:"id"`
	Object string `json:"object"`
	// Created is presumably a Unix timestamp in seconds — TODO confirm.
	Created int    `json:"created"`
	Model   string `json:"model"`
	// Usage holds the token counts reported by the API.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	// Choices are the returned replies; sendChatRequest concatenates the
	// Message.Content of each one.
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
	} `json:"choices"`
}
func sendChatRequest(rc *RequestClient, input string) (string, error) { | |
req := ChatCompletionRequest{ | |
Model: model(), | |
Messages: []ChatCompletionRequestMessages{ | |
{ | |
Role: "user", | |
Content: input, | |
}, | |
}, | |
} | |
rsp := ChatCompletionResponse{} | |
err := rc.Post("/v1/chat/completions", req, &rsp) | |
if err != nil { | |
return "", fmt.Errorf("error sending request: %w", err) | |
} | |
out := "" | |
for _, v := range rsp.Choices { | |
out += v.Message.Content | |
} | |
// Trim the newlines from the beginning. | |
out = strings.TrimPrefix(out, "\n\n") | |
return out, nil | |
} | |
// model returns the current model used for requests. | |
func model() string { | |
model := "gpt-3.5-turbo" | |
if modelSelector { | |
model = "gpt-3.5-turbo-0301" | |
} | |
return model | |
} | |
// ImageRequest is the JSON body that sendImageRequest POSTs to
// /v1/images/generations.
type ImageRequest struct {
	// Prompt is the text description of the image.
	Prompt string `json:"prompt"`
	// N is presumably the number of images requested; sendImageRequest sets 1.
	N int `json:"n"`
	// Size is the image size string; sendImageRequest uses "1024x1024".
	Size string `json:"size"`
}
// ImageResponse is the decoded JSON response of /v1/images/generations.
type ImageResponse struct {
	// Created is presumably a Unix timestamp in seconds — TODO confirm.
	Created int `json:"created"`
	// Data holds one entry per generated image; sendImageRequest reads URL.
	Data []struct {
		URL string `json:"url"`
	} `json:"data"`
}
func sendImageRequest(rc *RequestClient, input string) (string, error) { | |
req := ImageRequest{ | |
Prompt: input, | |
N: 1, | |
Size: "1024x1024", | |
} | |
rsp := ImageResponse{} | |
err := rc.Post("/v1/images/generations", req, &rsp) | |
if err != nil { | |
return "", fmt.Errorf("error sending request: %w", err) | |
} | |
out := "" | |
for _, v := range rsp.Data { | |
out += v.URL | |
} | |
// Trim the newlines from the beginning. | |
out = strings.TrimPrefix(out, "\n\n") | |
return out, nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"reflect"
	"time"
)
// RequestClient is for making HTTP requests against a single base URL. Every
// request carries an optional bearer token and the headers registered with
// AddHeader.
type RequestClient struct {
	baseURL      string
	bearerToken  string
	headerKeys   []string
	headerValues []string
}

// requestTimeout bounds a whole request/response exchange so a stalled
// connection cannot hang the caller forever. It is generous because
// generation endpoints can be slow.
const requestTimeout = 3 * time.Minute

// httpClient is shared by all RequestClients so connections are reused.
// Unlike http.DefaultClient, it has a timeout.
var httpClient = &http.Client{Timeout: requestTimeout}

// NewRequestClient returns a request client that targets baseURL. If
// bearerToken is non-empty, it is sent as an "Authorization: Bearer" header.
func NewRequestClient(baseURL string, bearerToken string) *RequestClient {
	return &RequestClient{
		baseURL:      baseURL,
		bearerToken:  bearerToken,
		headerKeys:   make([]string, 0),
		headerValues: make([]string, 0),
	}
}

// AddHeader adds a header to all requests. Headers are applied in insertion
// order with Header.Set, so a later value for the same name wins.
func (c *RequestClient) AddHeader(name string, value string) {
	c.headerKeys = append(c.headerKeys, name)
	c.headerValues = append(c.headerValues, value)
}

// Get makes a GET request to baseURL+urlSuffix. If returnData is non-nil, it
// must be a pointer, and the JSON response body is decoded into it.
//
// sendData is accepted for symmetry with Post but is ignored: GET requests are
// sent without a body.
func (c *RequestClient) Get(urlSuffix string, sendData interface{}, returnData interface{}) error {
	return c.do(http.MethodGet, urlSuffix, nil, returnData)
}

// Post makes a POST request to baseURL+urlSuffix. A non-nil sendData is
// JSON-encoded as the request body. If returnData is non-nil, it must be a
// pointer, and the JSON response body is decoded into it.
func (c *RequestClient) Post(urlSuffix string, sendData interface{}, returnData interface{}) error {
	return c.do(http.MethodPost, urlSuffix, sendData, returnData)
}

// do performs a request with the given method to baseURL+urlSuffix. It is the
// shared implementation of Get and Post.
//
// A non-nil sendData is JSON-encoded as the body. A non-nil returnData, which
// must be a pointer, receives the decoded JSON response. Any status other than
// 200 becomes an error containing the method, the status code and the raw
// response body. Underlying errors are wrapped with %w.
func (c *RequestClient) do(method string, urlSuffix string, sendData interface{}, returnData interface{}) error {
	// Ensure supported returnData was passed in (should be pointer).
	if returnData != nil && reflect.ValueOf(returnData).Kind() != reflect.Ptr {
		return errors.New("data must pass a pointer, not a value")
	}

	// Leave reqBody as a nil interface when there is nothing to send.
	var reqBody io.Reader
	if sendData != nil {
		sendJSON, err := json.Marshal(sendData)
		if err != nil {
			return fmt.Errorf("error marshal: %w", err)
		}
		reqBody = bytes.NewReader(sendJSON)
	}

	req, err := http.NewRequest(method, c.baseURL+urlSuffix, reqBody)
	if err != nil {
		return fmt.Errorf("error creating request: %w", err)
	}

	if len(c.bearerToken) > 0 {
		req.Header.Set("Authorization", "Bearer "+c.bearerToken)
	}
	// Add each header.
	for i, key := range c.headerKeys {
		req.Header.Set(key, c.headerValues[i])
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("error making %s request: %w", method, err)
	}
	// Always close the body; otherwise the connection is leaked rather than
	// returned to the transport for reuse.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("error reading body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("error %s response (%v): %v", method, resp.StatusCode, string(respBody))
	}

	if returnData != nil {
		if err := json.Unmarshal(respBody, returnData); err != nil {
			return fmt.Errorf("error unmarshal: %w", err)
		}
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.