Created April 2, 2023 21:42
-
-
Save abatilo/56521166eae5812a116bb1476e1a764f to your computer and use it in GitHub Desktop.
Build an embeddings index of documents and search it with OpenAI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gogptindex | |
import ( | |
"context" | |
_ "embed" | |
"encoding/json" | |
"fmt" | |
"net/http" | |
"os" | |
"os/signal" | |
"path/filepath" | |
"syscall" | |
"github.com/rs/zerolog" | |
"github.com/sashabaranov/go-openai" | |
"github.com/spf13/cobra" | |
"github.com/spf13/viper" | |
"golang.org/x/time/rate" | |
) | |
// FlagDataDir is the name of the flag (and the viper key it is bound to)
// holding the directory whose files the build subcommand embeds.
const FlagDataDir = "data-dir"

// FlagOutputDir is the name of the flag (and the viper key it is bound to)
// holding the directory the build subcommand writes index.json into.
const FlagOutputDir = "output-dir"
// indexJSON is the serialized index that the serve subcommand decodes into a
// map of Documents. It is compiled into the binary from index.json in this
// package's directory (presumably the file produced by the build subcommand).
//
//go:embed index.json
var indexJSON []byte
// Document is one entry of the index: a source text paired with its
// embedding vector. The JSON keys ("text", "vector") are the on-disk format
// of index.json, which build encodes and serve decodes, so they must not
// change independently of existing index files.
type Document struct {
	// Content is the full text of the source file.
	Content string `json:"text"`
	// Embedding is the vector the OpenAI embeddings API returned for Content.
	Embedding []float32 `json:"vector"`
}
// serve returns the "serve" subcommand, which answers questions over HTTP
// using the embedded index.
//
// Every request reads the question from the "query" URL parameter, embeds it
// with OpenAI, retrieves the most similar indexed documents, and asks the chat
// model to answer using those documents as context. The answer is written as
// JSON: {"response": "..."}. The server listens on :8000 and shuts down
// gracefully on SIGINT/SIGTERM.
func serve(log zerolog.Logger) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "serve",
		Short: "Serve the index in a web server",
		Run: func(cmd *cobra.Command, args []string) {
			const addr = ":8000"

			var index map[string]Document
			if err := json.Unmarshal(indexJSON, &index); err != nil {
				log.Fatal().Err(err).Msg("Failed to unmarshal index")
			}
			// Log the size, not the raw JSON: the index holds every document
			// and embedding and would flood the log.
			log.Info().Int("documents", len(index)).Msg("Serving index")

			client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

			mux := http.NewServeMux()
			mux.Handle("/", newQueryHandler(log, client, index))
			srv := &http.Server{
				Addr:    addr,
				Handler: mux,
			}

			// On SIGINT/SIGTERM stop accepting connections, let in-flight
			// requests finish, then signal done so Run can return.
			done := make(chan struct{})
			quit := make(chan os.Signal, 1)
			signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
			go func() {
				<-quit
				log.Info().Msg("Shutting down gracefully")
				if err := srv.Shutdown(context.Background()); err != nil {
					log.Error().Err(err).Msg("Failed to shut down server cleanly")
				}
				close(done)
			}()

			log.Info().Str("addr", addr).Msg("Starting server")
			// ListenAndServe returns http.ErrServerClosed as soon as Shutdown
			// is called; any other error means the server failed to start/run.
			if err := srv.ListenAndServe(); err != http.ErrServerClosed {
				log.Fatal().Err(err).Msg("Failed to start server")
			}
			// Shutdown may still be draining requests; wait for it.
			<-done
		},
	}
	return cmd
}

// newQueryHandler returns the HTTP handler that answers one question per
// request using retrieval over index followed by a chat completion.
func newQueryHandler(log zerolog.Logger, client *openai.Client, index map[string]Document) http.Handler {
	// neighborCount is how many of the most similar documents are given to
	// the model as context.
	const neighborCount = 3

	type queryResponse struct {
		Response string `json:"response"`
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Tie the OpenAI calls to the request so they are abandoned if the
		// client goes away.
		ctx := r.Context()

		query := r.URL.Query().Get("query")
		if query == "" {
			http.Error(w, "Missing query parameter", http.StatusBadRequest)
			return
		}
		log.Debug().Str("query", query).Msg("Received query")

		embeddingsResponse, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
			Input: []string{query},
			Model: openai.AdaEmbeddingV2,
		})
		if err != nil {
			log.Error().Err(err).Msg("Failed to get embeddings")
			http.Error(w, "Failed to get embeddings", http.StatusInternalServerError)
			return
		}
		if len(embeddingsResponse.Data) == 0 {
			log.Error().Msg("Embeddings response contained no data")
			http.Error(w, "Failed to get embeddings", http.StatusInternalServerError)
			return
		}
		queryEmbedding := embeddingsResponse.Data[0].Embedding

		// The index may hold fewer usable documents than neighborCount, so
		// build the context from however many come back, skipping empty text.
		neighbors := nearestNeighbors(queryEmbedding, index, neighborCount)
		passages := make([]string, 0, len(neighbors))
		for _, neighbor := range neighbors {
			if neighbor.Content == "" {
				continue
			}
			log.Debug().Str("content", neighbor.Content).Msg("Found neighbor")
			passages = append(passages, neighbor.Content)
		}

		prompt := buildQAPrompt(passages, query)
		log.Debug().Str("prompt", prompt).Msg("Generated prompt")

		resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
			Model: openai.GPT3Dot5Turbo,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    "system",
					Content: "You are a question answering bot. You will be given some context and a question. You must use the context to answer the question. If the context is not enough to answer the question, you will respond with \"The answer to that question is not in the context data.\"",
				},
				{
					Role:    "user",
					Content: prompt,
				},
			},
		})
		if err != nil {
			log.Error().Err(err).Msg("Failed to get completion")
			http.Error(w, "Failed to get completion", http.StatusInternalServerError)
			return
		}
		if len(resp.Choices) == 0 {
			log.Error().Msg("Completion response contained no choices")
			http.Error(w, "Failed to get completion", http.StatusInternalServerError)
			return
		}
		answer := resp.Choices[0].Message.Content
		log.Debug().Str("completion", answer).Msg("Got completion")

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(queryResponse{Response: answer}); err != nil {
			// The status line and possibly part of the body are already sent,
			// so an http.Error here could not change the response; just log.
			log.Error().Err(err).Msg("Failed to encode response")
		}
	})
}

// buildQAPrompt formats the retrieved passages and the user's question into
// the user prompt sent to the chat model. Any number of passages, including
// none, is accepted; each passage is followed by a newline.
func buildQAPrompt(passages []string, query string) string {
	joined := ""
	for _, p := range passages {
		joined += p + "\n"
	}
	return fmt.Sprintf("Context:\n%sQuestion:\n%s\nAnswer:\n", joined, query)
}
// build returns the "build" subcommand, which embeds every regular file in
// the data directory (--data-dir) with OpenAI and writes the result to
// index.json in the output directory (--output-dir).
//
// index.json is a JSON object mapping each file name to a Document holding
// the file's text and its embedding vector. Any failure is fatal.
func build(log zerolog.Logger) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "build",
		Short: "Build the index",
		Run: func(cmd *cobra.Command, args []string) {
			ctx := context.Background()
			log.Debug().Msg("Building index")

			dataDir := viper.GetString(FlagDataDir)
			files, err := os.ReadDir(dataDir)
			if err != nil {
				log.Fatal().Err(err).Msg("Failed to read data directory")
			}

			// limit requests to 3000 per minute == 50 per second
			// https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits
			limiter := rate.NewLimiter(50, 50)
			client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
			documents := make(map[string]Document, len(files))
			for _, file := range files {
				// Subdirectories are not documents, and os.ReadFile would
				// fail on them.
				if file.IsDir() {
					log.Debug().Str("dir", file.Name()).Msg("Skipping directory")
					continue
				}
				log.Info().Str("file", file.Name()).Msg("Processing file")
				fileBytes, err := os.ReadFile(filepath.Join(dataDir, file.Name()))
				if err != nil {
					log.Fatal().Err(err).Msg("Failed to read file")
				}
				fileContents := string(fileBytes)

				if err := limiter.Wait(ctx); err != nil {
					log.Fatal().Err(err).Msg("Rate limiter failed")
				}
				embeddingsResponse, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
					Input: []string{fileContents},
					Model: openai.AdaEmbeddingV2,
				})
				if err != nil {
					log.Fatal().Err(err).Msg("Failed to get embeddings")
				}
				if len(embeddingsResponse.Data) == 0 {
					log.Fatal().Str("file", file.Name()).Msg("Embeddings response contained no data")
				}
				documents[file.Name()] = Document{
					Content:   fileContents,
					Embedding: embeddingsResponse.Data[0].Embedding,
				}
			}
			// Log the count rather than the documents: dumping every
			// embedding vector makes the debug log unreadable.
			log.Debug().Int("documents", len(documents)).Msg("Done")

			output := filepath.Join(viper.GetString(FlagOutputDir), "index.json")
			f, err := os.Create(output)
			if err != nil {
				log.Fatal().Err(err).Msg("Failed to create index file")
			}
			if err := json.NewEncoder(f).Encode(documents); err != nil {
				_ = f.Close() // best effort; the encode error is the one reported
				log.Fatal().Err(err).Msg("Failed to write index file")
			}
			// Close explicitly and check the error: a failed close can mean
			// the index was not fully written to disk.
			if err := f.Close(); err != nil {
				log.Fatal().Err(err).Msg("Failed to close index file")
			}
		},
	}
	cmd.PersistentFlags().String(FlagDataDir, "./cmd/resumegpt/gogptindex/data", "Path to source training data")
	cmd.PersistentFlags().String(FlagOutputDir, "./cmd/resumegpt/gogptindex/", "Path to output the embeddings")
	if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {
		log.Error().Err(err).Msg("Failed to bind build flags")
	}
	return cmd
}
// Cmd returns the "index" command group, which bundles the subcommands for
// serving queries against the GPT index (serve) and building it (build).
func Cmd(log zerolog.Logger) *cobra.Command {
	group := &cobra.Command{Use: "index", Short: "Commands for managing our GPT index"}
	for _, sub := range []*cobra.Command{serve(log), build(log)} {
		group.AddCommand(sub)
	}
	return group
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gogptindex | |
import "math" | |
// cosineSimilarity returns the cosine similarity between two vectors: their
// dot product divided by the product of their Euclidean norms. The result is
// in [-1, 1], where 1 means the vectors point in the same direction.
//
// The similarity is undefined when the vectors differ in length or either has
// zero norm; in those cases it returns 0.
func cosineSimilarity(a []float32, b []float32) float32 {
	if len(a) != len(b) {
		return 0
	}
	// Accumulate in float64 to limit rounding error over long embedding
	// vectors.
	var dot, sumSqA, sumSqB float64
	for i := range a {
		x, y := float64(a[i]), float64(b[i])
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}
	if sumSqA == 0 || sumSqB == 0 {
		return 0
	}
	// Divide by the norms (square roots of the sums of squares), not by the
	// sums of squares themselves.
	return float32(dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB)))
}
// nearestNeighbors returns up to k documents from index whose embeddings are
// most cosine-similar to query.
//
// The result is keyed by each document's Content, so documents with identical
// text collapse into one entry. When the index holds fewer than k documents,
// only the real documents are returned (no zero-value placeholders). Documents
// whose similarity is NaN are never returned, and a non-positive k yields an
// empty map. Ties are broken by map iteration order.
func nearestNeighbors(query []float32, index map[string]Document, k int) map[string]Document {
	if k <= 0 || len(index) == 0 {
		return map[string]Document{}
	}
	if k > len(index) {
		k = len(index)
	}

	// best and scores hold the current top-k candidates in parallel. They
	// grow until k slots are filled; after that, the weakest slot is replaced
	// whenever a more similar document appears.
	best := make([]Document, 0, k)
	scores := make([]float32, 0, k)
	for _, doc := range index {
		score := cosineSimilarity(query, doc.Embedding)
		if math.IsNaN(float64(score)) {
			continue
		}
		if len(best) < k {
			best = append(best, doc)
			scores = append(scores, score)
			continue
		}

		weakest := 0
		for i := 1; i < len(scores); i++ {
			if scores[i] < scores[weakest] {
				weakest = i
			}
		}
		if score > scores[weakest] {
			scores[weakest] = score
			best[weakest] = doc
		}
	}

	neighbors := make(map[string]Document, len(best))
	for _, doc := range best {
		neighbors[doc.Content] = doc
	}
	return neighbors
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment