Skip to content

Instantly share code, notes, and snippets.

@ibuildthecloud
Created June 3, 2024 21:42
Show Gist options
  • Save ibuildthecloud/ec96353c27502a99371387769594277d to your computer and use it in GitHub Desktop.
Save ibuildthecloud/ec96353c27502a99371387769594277d to your computer and use it in GitHub Desktop.
Simple function calling proxy
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"slices"
"strings"
openai "github.com/gptscript-ai/chat-completion-client"
)
// writeTool appends a TypeScript-style declaration of tool to buf, for use in
// the system prompt that describes the available functions to the model:
//
//	// <tool description>
//	type <tool name> = (_: {
//	// <parameter description>
//	<parameter name>?: string,
//	}) => any;
//
// Every parameter is declared as an optional string regardless of its JSON
// schema type. Parameters are written in sorted order so the generated prompt
// is identical across requests (Go map iteration order is random). A tool
// without a "properties" object gets an empty parameter list, and a property
// without a string "description" gets an empty comment line.
func writeTool(buf *strings.Builder, tool openai.Tool) {
	buf.WriteString("// ")
	buf.WriteString(tool.Function.Description)
	buf.WriteString("\n")
	buf.WriteString("type ")
	buf.WriteString(tool.Function.Name)
	buf.WriteString(" = (_: {\n")

	// Parameters is presumably decoded JSON (map[string]any) when the request
	// arrives over HTTP. Use checked assertions so a tool without parameters,
	// or with an unexpected schema shape, does not panic the handler.
	schema, _ := tool.Function.Parameters.(map[string]any)
	props, _ := schema["properties"].(map[string]any)

	names := make([]string, 0, len(props))
	for name := range props {
		names = append(names, name)
	}
	slices.Sort(names)

	for _, name := range names {
		prop, _ := props[name].(map[string]any)
		desc, _ := prop["description"].(string)
		buf.WriteString("// ")
		buf.WriteString(desc)
		buf.WriteString("\n")
		buf.WriteString(name)
		buf.WriteString("?: string,\n")
	}
	buf.WriteString("}) => any;\n\n")
}
// translate rewrites req for an upstream model without native tool support.
// If req declares no tools it is returned unchanged. Otherwise:
//   - the tools are described (via writeTool) in a TypeScript "functions"
//     namespace, followed by instructions to answer with
//     <CALL>name({...})</CALL>; this text is prepended to the first system
//     message, or inserted as a new leading system message if there is none;
//   - earlier assistant tool calls are replayed in that <CALL> text form,
//     after any text the assistant message already had;
//   - tool results become user messages wrapped in <CALL_RESULT> tags;
//   - streaming is forced on and "</CALL>" is set as the stop sequence;
//   - req.Tools is cleared.
//
// The caller's Messages slice is not modified; a copy is rewritten.
func translate(req openai.ChatCompletionRequest) openai.ChatCompletionRequest {
	if len(req.Tools) == 0 {
		return req
	}

	buf := strings.Builder{}
	buf.WriteString("You can have the user call the following functions.\n")
	buf.WriteString("```typescript\n")
	buf.WriteString("## functions\n")
	buf.WriteString("\n")
	buf.WriteString("namespace functions {\n\n")
	for _, tool := range req.Tools {
		writeTool(&buf, tool)
	}
	buf.WriteString("} // namespace functions \n")
	buf.WriteString("```\n\n")
	buf.WriteString("To tell the user to call a function use the following format\n")
	buf.WriteString("<CALL>functionToCall({\"arg1\":\"value1\",\"arg2\":\"value2\"})</CALL>\n\n")

	req.Stream = true
	req.Stop = []string{"</CALL>"}

	// req is a copy, but its Messages slice still shares the caller's backing
	// array; clone it so the rewrite below does not leak back to the caller.
	req.Messages = slices.Clone(req.Messages)

	var systemSet bool
	for i, msg := range req.Messages {
		switch {
		case msg.Role == openai.ChatMessageRoleSystem && !systemSet:
			msg.Content = buf.String() + msg.Content
			systemSet = true
		case msg.Role == openai.ChatMessageRoleAssistant && len(msg.ToolCalls) > 0:
			// Replay every call (not just the first) in the same text format the
			// model is asked to produce, keeping any text it wrote before them.
			var calls strings.Builder
			calls.WriteString(msg.Content)
			for _, tc := range msg.ToolCalls {
				calls.WriteString("<CALL>" + tc.Function.Name + "(" + tc.Function.Arguments + ")</CALL>")
			}
			msg.Content = calls.String()
			msg.ToolCalls = nil
		case msg.Role == openai.ChatMessageRoleTool:
			msg.Role = openai.ChatMessageRoleUser
			msg.Content = "<CALL_RESULT>" + msg.Content + "</CALL_RESULT>"
			msg.FunctionCall = nil
		}
		req.Messages[i] = msg
	}

	if !systemSet {
		req.Messages = slices.Insert(req.Messages, 0, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: buf.String(),
		})
	}
	req.Tools = nil
	return req
}
// toContent returns a reader that yields only the assistant text carried in an
// OpenAI-style server-sent-event stream read from resp.Body. For each
// "data: {...}" line it decodes the chunk and writes choices[0].delta.content.
// Reading stops at "data: [DONE]" or at the end of the body. If reading the
// body fails, the returned reader reports that error instead of a clean EOF.
//
// A goroutine feeds the returned reader through an io.Pipe, so the caller must
// read it until EOF or an error, or that goroutine stays blocked. resp.Body is
// not closed here.
func toContent(resp *http.Response) io.Reader {
	r, w := io.Pipe()
	go func() {
		var err error
		// A nil err closes the pipe normally (EOF for the reader); a non-nil err
		// lets the reader distinguish a truncated stream from a finished one.
		defer func() { _ = w.CloseWithError(err) }()

		reader := bufio.NewScanner(resp.Body)
		// A single SSE data line can exceed bufio.Scanner's default 64 KiB limit.
		reader.Buffer(make([]byte, 0, 64*1024), 1024*1024)
		for reader.Scan() {
			var delta openai.ChatCompletionStreamResponse
			fmt.Println("LINE: ", reader.Text())
			line := strings.TrimPrefix(reader.Text(), "data: ")
			if line == "" {
				continue
			}
			if line == "[DONE]" {
				break
			}
			if jerr := json.Unmarshal([]byte(line), &delta); jerr != nil {
				// Not a JSON chunk; log it and keep going.
				fmt.Println("ERR:", jerr.Error())
				continue
			}
			if len(delta.Choices) == 0 {
				continue
			}
			content := delta.Choices[0].Delta.Content
			fmt.Println("CONTENT: ", content)
			if content == "" {
				// A zero-length pipe write still hands the reader a 0-byte Read;
				// bufio.Scanner gives up with io.ErrNoProgress after 100 of those
				// in a row, so skip empty deltas entirely.
				continue
			}
			if _, err = io.WriteString(w, content); err != nil {
				// The read side was closed; nobody wants the rest.
				return
			}
		}
		err = reader.Err()
	}()
	return r
}
// scanner holds the state of the bufio.SplitFunc built by newSplitter.
type scanner struct {
	// delims are the markers still to be found, in order. Once it is empty the
	// scanner is inside the call's argument list.
	delims [][]byte
}

// newSplitter returns a stateful bufio.SplitFunc for model output of the form
//
//	free text<CALL>functionName({"arg":"value"})
//
// It yields, in order: chunks of free text, the "<CALL>" marker, chunks of the
// function name, the "(" marker, and chunks of the argument list. Everything
// after "(" is argument text, so parentheses inside the arguments pass through
// untouched. Only a ")" that is the very last byte of the input is dropped, as
// the call's closing parenthesis. translate sets "</CALL>" as the stop
// sequence, so a well-formed stream ends right after that ")".
//
// Chunk boundaries depend on how the input arrives, so callers should
// concatenate chunks of each kind rather than rely on them. Each returned
// function carries its own state and must be used with a single bufio.Scanner.
func newSplitter() bufio.SplitFunc {
	s := &scanner{
		delims: [][]byte{
			[]byte("<CALL>"),
			[]byte("("),
		},
	}
	return s.split
}

// split implements bufio.SplitFunc; see newSplitter.
func (s *scanner) split(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if len(data) == 0 && atEOF {
		return 0, nil, io.EOF
	}
	if len(s.delims) == 0 {
		return splitArgs(data, atEOF)
	}

	delim := s.delims[0]
	i := bytes.IndexByte(data, delim[0])
	switch {
	case i < 0:
		// No possible start of the delimiter: all of data is plain text.
		return len(data), data, nil
	case i > 0:
		// Emit the text before a possible delimiter start.
		return i, data[:i], nil
	case bytes.HasPrefix(data, delim):
		s.delims = s.delims[1:]
		return len(delim), data[:len(delim)], nil
	case !atEOF && bytes.HasPrefix(delim, data):
		// data is a proper prefix of delim and may still become it.
		return 0, nil, nil
	default:
		// data[0] starts something other than delim. Emit only that byte so a
		// real delimiter beginning inside data[1:] is still recognized.
		return 1, data[:1], nil
	}
}

// splitArgs splits argument text. A trailing ")" is held back until it is
// known whether more input follows; at end of input it is dropped.
func splitArgs(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if len(data) == 0 {
		return 0, nil, nil
	}
	last := len(data) - 1
	if data[last] != ')' {
		return len(data), data, nil
	}
	switch {
	case !atEOF && last == 0:
		// A lone ")": wait to see whether anything follows it.
		return 0, nil, nil
	case !atEOF:
		return last, data[:last], nil
	case last == 0:
		// The closing ")" at end of input: drop it and stop.
		return len(data), nil, io.EOF
	default:
		return len(data), data[:last], nil
	}
}
// main starts an OpenAI-compatible chat-completions proxy on :8081. The proxy
// adds function calling on top of an upstream model (target) that lacks native
// tool support. See handleChat for the per-request flow.
func main() {
	var target = "http://station76.local:11434/v1/chat/completions"
	//var target = "http://station76.local:8081/v1/chat/completions"
	log.Print("Listening on :8081")
	// ListenAndServe only returns on failure (e.g. the port is already taken).
	log.Fatal(http.ListenAndServe(":8081", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handleChat(w, r, target)
	})))
}

// handleChat serves one chat-completions request. It decodes the body and
// rewrites tools into prompt text with translate. It then forwards the request
// to target, bound to the client's context, and streams the upstream answer
// back as server-sent events.
func handleChat(w http.ResponseWriter, r *http.Request, target string) {
	var req openai.ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	req = translate(req)
	data, err := json.Marshal(req)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	fmt.Println("REQUEST:", string(data))

	// Cancelling with the client's context abandons the upstream call when the
	// client disconnects.
	upReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, target, bytes.NewReader(data))
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// The request body is JSON; only the response is an event stream.
	upReq.Header.Set("Content-Type", "application/json")
	upReq.Header.Set("Accept", "text/event-stream")
	upResp, err := http.DefaultClient.Do(upReq)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer upResp.Body.Close()

	if upResp.StatusCode < 200 || upResp.StatusCode > 299 {
		// Relay the upstream error rather than parsing it as an event stream.
		msg, _ := io.ReadAll(io.LimitReader(upResp.Body, 1<<20))
		http.Error(w, string(msg), upResp.StatusCode)
		return
	}
	if !req.Stream {
		// translate only leaves Stream unset when there were no tools. The
		// answer is then a single JSON document that needs no rewriting.
		if ct := upResp.Header.Get("Content-Type"); ct != "" {
			w.Header().Set("Content-Type", ct)
		}
		_, _ = io.Copy(w, upResp.Body)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	streamResponse(w, toContent(upResp))
}

// streamResponse splits the model text read from content with newSplitter. It
// writes the result to w as chat.completion.chunk events and ends with
// "data: [DONE]":
//   - text outside a call becomes content deltas;
//   - "<CALL>name(args" becomes one tool call at index 0: a first chunk with
//     the name ("functions." prefix stripped), then one chunk per argument
//     fragment.
//
// Only the first <CALL> in the output is recognized.
func streamResponse(w http.ResponseWriter, content io.Reader) {
	sc := bufio.NewScanner(content)
	sc.Split(newSplitter())

	var (
		inCall   bool   // "<CALL>" has been seen
		inArgs   bool   // the "(" after the function name has been seen
		funcName string // name accumulated between "<CALL>" and "("
	)
	for sc.Scan() {
		token := sc.Text()
		switch {
		case !inCall && token == "<CALL>":
			inCall = true
		case !inCall:
			writeChunk(w, openai.ChatCompletionStreamChoiceDelta{
				Role:    openai.ChatMessageRoleAssistant,
				Content: token,
			})
		case !inArgs && token == "(":
			inArgs = true
			// Announce the call as soon as the name is complete, so that a call
			// with no arguments is still delivered.
			writeToolCallChunk(w, funcName, "")
		case !inArgs:
			fmt.Println("FUNC: ", token)
			funcName += token
		default:
			writeToolCallChunk(w, "", token)
		}
	}
	if err := sc.Err(); err != nil {
		log.Printf("reading upstream stream: %v", err)
	}
	if inCall && !inArgs && funcName != "" {
		// The model stopped before "(": still report the call, without arguments.
		writeToolCallChunk(w, funcName, "")
	}
	// Drain whatever is left so toContent's goroutine is never left blocked on
	// its pipe write if scanning stopped early.
	_, _ = io.Copy(io.Discard, content)
	_, _ = io.WriteString(w, "data: [DONE]\n\n")
	flush(w)
}

// writeToolCallChunk writes one tool-call delta for call index 0. Pass the name
// only on a call's first chunk and "" afterwards. Clients that concatenate
// streamed fragments would otherwise see the name repeated.
func writeToolCallChunk(w http.ResponseWriter, name, args string) {
	writeChunk(w, openai.ChatCompletionStreamChoiceDelta{
		Role: openai.ChatMessageRoleAssistant,
		ToolCalls: []openai.ToolCall{
			{
				Index: new(int),
				Type:  openai.ToolTypeFunction,
				Function: openai.FunctionCall{
					// The prompt declares tools inside "namespace functions", so
					// the model may emit a qualified name.
					Name:      strings.TrimPrefix(name, "functions."),
					Arguments: args,
				},
			},
		},
	})
}

// writeChunk writes delta as one server-sent event ("data: <json>\n\n") and
// flushes it so the client receives it immediately.
func writeChunk(w http.ResponseWriter, delta openai.ChatCompletionStreamChoiceDelta) {
	data, err := json.Marshal(openai.ChatCompletionStreamResponse{
		Choices: []openai.ChatCompletionStreamChoice{{Delta: delta}},
	})
	if err != nil {
		log.Printf("encoding stream chunk: %v", err)
		return
	}
	_, _ = fmt.Fprintf(w, "data: %s\n\n", data)
	flush(w)
}

// flush pushes any buffered response bytes to the client, if w supports it.
func flush(w http.ResponseWriter) {
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment