Skip to content

Instantly share code, notes, and snippets.

@abatilo
Created April 2, 2023 21:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save abatilo/56521166eae5812a116bb1476e1a764f to your computer and use it in GitHub Desktop.
Save abatilo/56521166eae5812a116bb1476e1a764f to your computer and use it in GitHub Desktop.
Build index and search index with OpenAI
package gogptindex
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"github.com/rs/zerolog"
"github.com/sashabaranov/go-openai"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/time/rate"
)
// Names of the command-line flags registered by the build command. The same
// strings are used as viper keys when the flag values are read back.
const (
// FlagDataDir names the flag holding the directory of source files to embed.
FlagDataDir = "data-dir"
// FlagOutputDir names the flag holding the directory index.json is written to.
FlagOutputDir = "output-dir"
)
// indexJSON holds the raw bytes of index.json, embedded at compile time. The
// serve command unmarshals it into a map[string]Document.
//go:embed index.json
var indexJSON []byte
// Document is one entry of the index: a piece of text paired with the
// embedding vector computed for it.
type Document struct {
// Content is the document text; it is stored under the JSON key "text".
Content string `json:"text"`
// Embedding is the embedding vector for Content; it is stored under the JSON key "vector".
Embedding []float32 `json:"vector"`
}
// serve returns the "serve" subcommand, which answers questions over HTTP
// using the index embedded in the binary.
//
// Every request to "/" must carry a non-empty `query` URL parameter. The
// handler embeds the query, selects up to three of the most similar indexed
// documents, and asks the chat model to answer the query using those
// documents as context. The answer is written as JSON of the form
// {"response": "..."}. The server listens on :8000 until the process receives
// SIGINT or SIGTERM, at which point it shuts down gracefully.
func serve(log zerolog.Logger) *cobra.Command {
	const (
		// addr is declared once so the listener and the startup log cannot drift apart.
		addr = ":8000"
		// neighborCount is how many similar documents are offered as context.
		neighborCount = 3
	)

	cmd := &cobra.Command{
		Use:   "serve",
		Short: "Serve the index in a web server",
		Run: func(cmd *cobra.Command, args []string) {
			ctx := context.Background()
			client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

			var index map[string]Document
			if err := json.Unmarshal(indexJSON, &index); err != nil {
				log.Fatal().Err(err).Msg("Failed to unmarshal index")
			}
			// Log the size instead of the raw JSON: the index carries every
			// document plus its full embedding vector.
			log.Info().Int("documents", len(index)).Msg("Serving index")

			type QueryResponse struct {
				Response string `json:"response"`
			}

			mux := http.NewServeMux()
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				// Use the request context so upstream calls are abandoned
				// when the client goes away.
				reqCtx := r.Context()

				query := r.URL.Query().Get("query")
				log.Debug().Str("query", query).Msg("Received query")
				if query == "" {
					http.Error(w, "Missing query parameter", http.StatusBadRequest)
					return
				}

				embeddingsResponse, err := client.CreateEmbeddings(reqCtx, openai.EmbeddingRequest{
					Input: []string{query},
					Model: openai.AdaEmbeddingV2,
				})
				if err != nil {
					log.Error().Err(err).Msg("Failed to get embeddings")
					http.Error(w, "Failed to get embeddings", http.StatusInternalServerError)
					return
				}
				if len(embeddingsResponse.Data) == 0 {
					log.Error().Msg("Embeddings response contained no data")
					http.Error(w, "Failed to get embeddings", http.StatusInternalServerError)
					return
				}
				queryEmbedding := embeddingsResponse.Data[0].Embedding

				// Get the most similar documents. The result is a map keyed by
				// content, so it can hold fewer than neighborCount entries
				// (small index, duplicate content); never index it blindly.
				neighbors := nearestNeighbors(queryEmbedding, index, neighborCount)
				neighborsContent := make([]string, 0, len(neighbors))
				for _, neighbor := range neighbors {
					if neighbor.Content == "" {
						// An empty document adds nothing to the context.
						continue
					}
					log.Debug().Str("content", neighbor.Content).Msg("Found neighbor")
					neighborsContent = append(neighborsContent, neighbor.Content)
				}

				// One context entry per line, followed by the question.
				prompt := "Context:\n"
				for _, content := range neighborsContent {
					prompt += content + "\n"
				}
				prompt += fmt.Sprintf("Question:\n%s\nAnswer:\n", query)
				log.Debug().Str("prompt", prompt).Msg("Generated prompt")

				resp, err := client.CreateChatCompletion(reqCtx, openai.ChatCompletionRequest{
					Model: openai.GPT3Dot5Turbo,
					Messages: []openai.ChatCompletionMessage{
						{
							Role:    "system",
							Content: "You are a question answering bot. You will be given some context and a question. You must use the context to answer the question. If the context is not enough to answer the question, you will response with \"The answer to that question is not in the context data.\"",
						},
						{
							Role:    "user",
							Content: prompt,
						},
					},
				})
				if err != nil {
					log.Error().Err(err).Msg("Failed to get completion")
					http.Error(w, "Failed to get completion", http.StatusInternalServerError)
					return
				}
				if len(resp.Choices) == 0 {
					log.Error().Msg("Completion response contained no choices")
					http.Error(w, "Failed to get completion", http.StatusInternalServerError)
					return
				}
				answer := resp.Choices[0].Message.Content
				log.Debug().Str("completion", answer).Msg("Got completion")

				w.Header().Set("Content-Type", "application/json")
				if err := json.NewEncoder(w).Encode(QueryResponse{Response: answer}); err != nil {
					log.Error().Err(err).Msg("Failed to encode response")
					http.Error(w, "Failed to encode response", http.StatusInternalServerError)
					return
				}
			})

			srv := &http.Server{
				Addr:    addr,
				Handler: mux,
			}

			// Register signal handlers for graceful shutdown. done is closed
			// once Shutdown has returned so Run does not exit while requests
			// are still draining.
			done := make(chan struct{})
			quit := make(chan os.Signal, 1)
			signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
			go func() {
				<-quit
				log.Info().Msg("Shutting down gracefully")
				if err := srv.Shutdown(ctx); err != nil {
					log.Error().Err(err).Msg("Failed to shut down server cleanly")
				}
				close(done)
			}()

			log.Info().Str("addr", addr).Msg("Starting server")
			if err := srv.ListenAndServe(); err != http.ErrServerClosed {
				log.Fatal().Err(err).Msg("Failed to start server")
			}
			<-done
		},
	}
	return cmd
}
// build returns the "build" subcommand, which creates the index file.
//
// It reads every regular entry of the directory given by the data-dir flag,
// requests an embedding for each file's contents, and writes the resulting
// map of file name to Document as JSON to index.json inside the directory
// given by the output-dir flag. Any failure aborts the build via log.Fatal.
func build(log zerolog.Logger) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "build",
		Short: "Build the index",
		Run: func(cmd *cobra.Command, args []string) {
			ctx := context.Background()
			log.Debug().Msg("Building index")

			// List every single file in the data directory.
			dataDir := viper.GetString(FlagDataDir)
			files, err := os.ReadDir(dataDir)
			if err != nil {
				log.Fatal().Err(err).Msg("Failed to read data directory")
			}

			// limit requests to 3000 per minute == 50 per second
			// https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits
			limiter := rate.NewLimiter(50, 50)
			client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

			documents := make(map[string]Document, len(files))
			for _, file := range files {
				name := file.Name()
				if file.IsDir() {
					// os.ReadFile fails on directories; skip them rather
					// than aborting the whole build.
					log.Debug().Str("file", name).Msg("Skipping directory")
					continue
				}

				if err := limiter.Wait(ctx); err != nil {
					log.Fatal().Err(err).Msg("Failed to wait for rate limiter")
				}
				log.Info().Str("file", name).Msg("Processing file")

				fileBytes, err := os.ReadFile(filepath.Join(dataDir, name))
				if err != nil {
					log.Fatal().Err(err).Str("file", name).Msg("Failed to read file")
				}
				fileContents := string(fileBytes)

				embeddingsResponse, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
					Input: []string{fileContents},
					Model: openai.AdaEmbeddingV2,
				})
				if err != nil {
					log.Fatal().Err(err).Str("file", name).Msg("Failed to get embeddings")
				}
				if len(embeddingsResponse.Data) == 0 {
					log.Fatal().Str("file", name).Msg("Embeddings response contained no data")
				}

				documents[name] = Document{
					Content:   fileContents,
					Embedding: embeddingsResponse.Data[0].Embedding,
				}
			}
			log.Debug().Msgf("Done: %#v", documents)

			output := filepath.Join(viper.GetString(FlagOutputDir), "index.json")
			f, err := os.Create(output)
			if err != nil {
				log.Fatal().Err(err).Msg("Failed to create index file")
			}

			// Close explicitly instead of deferring: log.Fatal exits the
			// process without running deferred calls, and a failed Close can
			// mean the index was not fully written.
			encodeErr := json.NewEncoder(f).Encode(documents)
			closeErr := f.Close()
			if encodeErr != nil {
				log.Fatal().Err(encodeErr).Msg("Failed to write index file")
			}
			if closeErr != nil {
				log.Fatal().Err(closeErr).Msg("Failed to close index file")
			}
		},
	}

	cmd.PersistentFlags().String(FlagDataDir, "./cmd/resumegpt/gogptindex/data", "Path to source training data")
	cmd.PersistentFlags().String(FlagOutputDir, "./cmd/resumegpt/gogptindex/", "Path to output the embeddings")
	if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {
		// Without the binding the flag values never reach viper, so make
		// the failure visible instead of dropping it.
		log.Error().Err(err).Msg("Failed to bind flags")
	}
	return cmd
}
// Cmd builds the "index" parent command and attaches the serve and build
// subcommands to it. The given logger is handed to both subcommands.
func Cmd(log zerolog.Logger) *cobra.Command {
	root := &cobra.Command{
		Use:   "index",
		Short: "Commands for managing our GPT index",
	}

	// Register the subcommands one at a time, in the same order as before.
	subcommands := []*cobra.Command{serve(log), build(log)}
	for _, sub := range subcommands {
		root.AddCommand(sub)
	}

	return root
}
package gogptindex
import "math"
// cosineSimilarity returns the cosine similarity between two vectors:
// dot(a, b) / (|a| * |b|), a value in [-1, 1].
//
// The vectors are expected to have the same length. If they differ, only the
// shared prefix is compared, so a short or missing vector cannot cause an
// out-of-range panic. If either vector has zero magnitude the similarity is
// undefined and 0 is returned rather than NaN.
func cosineSimilarity(a []float32, b []float32) float32 {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}

	// Accumulate in float64 so rounding error does not build up over the
	// many dimensions of an embedding vector.
	var dotProduct, sumSquaresA, sumSquaresB float64
	for i := 0; i < n; i++ {
		x, y := float64(a[i]), float64(b[i])
		dotProduct += x * y
		sumSquaresA += x * x
		sumSquaresB += y * y
	}

	if sumSquaresA == 0 || sumSquaresB == 0 {
		return 0
	}
	// The magnitudes are the square roots of the sums of squares; dividing
	// by the raw sums would only be correct for unit-length vectors.
	return float32(dotProduct / (math.Sqrt(sumSquaresA) * math.Sqrt(sumSquaresB)))
}
// nearestNeighbors returns up to k documents from index whose embeddings are
// most similar to the query vector, as measured by cosineSimilarity.
//
// The result is keyed by document content, so documents with identical
// content collapse into a single entry and the map carries no ranking. It
// holds fewer than k entries when the index is smaller than k or when a
// document's similarity is not comparable (NaN). A non-positive k or an empty
// index yields an empty, non-nil map.
func nearestNeighbors(query []float32, index map[string]Document, k int) map[string]Document {
	neighborsMap := make(map[string]Document)
	if k > len(index) {
		// There is no point tracking more slots than there are documents.
		k = len(index)
	}
	if k <= 0 {
		return neighborsMap
	}

	// A slot whose similarity is still negative infinity has never been
	// filled; any real similarity compares greater than it.
	unfilled := float32(math.Inf(-1))
	neighbors := make([]Document, k)
	similarities := make([]float32, k)
	for i := range similarities {
		similarities[i] = unfilled
	}

	for _, doc := range index {
		sim := cosineSimilarity(query, doc.Embedding)

		// Find the slot currently holding the lowest similarity.
		lowest := 0
		for i, candidate := range similarities {
			if candidate < similarities[lowest] {
				lowest = i
			}
		}

		// Replace it only if this document is a better match.
		if sim > similarities[lowest] {
			similarities[lowest] = sim
			neighbors[lowest] = doc
		}
	}

	// Convert to a map, leaving out slots that were never filled so callers
	// do not receive a zero-value Document under the empty key.
	for i, doc := range neighbors {
		if similarities[i] == unfilled {
			continue
		}
		neighborsMap[doc.Content] = doc
	}
	return neighborsMap
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment