|
package main |
|
|
|
import ( |
|
"context" |
|
"fmt" |
|
"os" |
|
"os/exec" |
|
"strings" |
|
"time" |
|
|
|
"github.com/AlecAivazis/survey/v2" |
|
"github.com/PullRequestInc/go-gpt3" |
|
"github.com/briandowns/spinner" |
|
) |
|
|
|
// isGitRepo reports whether the current working directory is inside a git
// work tree.
//
// It runs `git rev-parse --is-inside-work-tree`. That command exits with
// status 0 but prints "false" inside a .git directory or a bare repository,
// so the printed answer is checked as well as the exit status. Any failure to
// run git (git not installed, not a repository, ...) is reported as false.
func isGitRepo() bool {
	out, err := exec.Command("git", "rev-parse", "--is-inside-work-tree").Output()
	return err == nil && strings.TrimSpace(string(out)) == "true"
}
|
|
|
// getStagedFiles returns the paths of the files currently staged in the git
// index, in the order git lists them, or an empty slice if nothing is staged.
//
// The names are requested with -z. This disables git's C-style quoting of
// unusual path names (non-ASCII, quotes, control characters) and
// NUL-terminates each entry, so the names are returned verbatim and can be
// handed back to git as pathspecs.
func getStagedFiles() []string {
	cmd := exec.Command("git", "diff", "--staged", "--name-only", "-z")
	// The error (exit code) is deliberately ignored: the caller checks that it
	// runs inside a git repository first, and any failure here simply yields
	// no staged files.
	out, _ := cmd.Output()

	staged := []string{}
	for _, name := range strings.Split(string(out), "\x00") {
		// The final NUL terminator (or empty output) produces an empty element.
		if name != "" {
			staged = append(staged, name)
		}
	}
	return staged
}
|
|
|
// getDiff returns the staged diff, with zero lines of context, restricted to
// stagedFiles. With no names, the whole staged diff is returned. If git fails,
// the error and git's stderr are printed and the process exits with status 1.
//
// Each name is passed to git as its own argument after "--". The
// ":(top,literal)" pathspec magic makes git resolve the name from the
// repository root and match wildcard characters (*, ?, [) literally. The names
// are expected to be root-relative, which is how `git diff --name-only` prints
// them (absent a diff.relative config), so this also works when the tool is
// run from a subdirectory.
func getDiff(stagedFiles []string) string {
	args := make([]string, 0, 4+len(stagedFiles))
	args = append(args, "diff", "--staged", "--unified=0", "--")
	for _, name := range stagedFiles {
		args = append(args, ":(top,literal)"+name)
	}

	out, err := exec.Command("git", args...).Output()
	if err != nil {
		fmt.Println(err)
		// Output stores git's stderr in the ExitError; it explains the failure
		// far better than the bare "exit status N".
		if exitErr, ok := err.(*exec.ExitError); ok && len(exitErr.Stderr) > 0 {
			fmt.Print(string(exitErr.Stderr))
		}
		os.Exit(1)
	}
	return string(out)
}
|
|
|
// checkDiffLinesLength reports whether diffLines has at most maxLines (150)
// lines, and also returns that limit so callers can mention it in messages.
//
// A trailing "\n" terminates the last line rather than starting a new, empty
// one. git diff output always ends with a newline, so a 150-line diff is
// accepted; the old strings.Split count saw 151 elements and rejected it.
func checkDiffLinesLength(diffLines string) (bool, int) {
	const maxLines = 150

	lines := strings.Count(diffLines, "\n")
	if diffLines != "" && !strings.HasSuffix(diffLines, "\n") {
		lines++ // the last line has no terminating newline
	}
	return lines <= maxLines, maxLines
}
|
|
|
func generateCommitMessages(diffLines string, client gpt3.Client) string { |
|
prompt := fmt.Sprintf("\n\n%s\n\nPlease suggest 5 semantic commit messages (prefixed with feat for features, fix for fixes, docs...) based on the above git diff, do not prefix lines with numbers, mention files and changes, use more than 100 chars in line if needed:\n", diffLines) |
|
|
|
s := spinner.New(spinner.CharSets[14], 100*time.Millisecond) |
|
s.Start() |
|
response, err := client.Completion(context.Background(), gpt3.CompletionRequest{ |
|
Prompt: []string{prompt}, |
|
MaxTokens: gpt3.IntPtr(350), |
|
// Stop: []string{"."}, |
|
Echo: true, |
|
}) |
|
s.Stop() |
|
|
|
if err != nil { |
|
fmt.Println(err) |
|
os.Exit(1) |
|
} |
|
|
|
choices := response.Choices[0].Text |
|
// get rid of the prompt |
|
choices = strings.ReplaceAll(choices, prompt, "") |
|
|
|
// remove the numbers from the beginning of the lines |
|
choices = strings.TrimSpace(strings.ReplaceAll(choices, "^\\d+\\.\\s?", "")) |
|
|
|
// remove empty choices |
|
choices = strings.TrimSpace(strings.ReplaceAll(choices, "\n\n", "\n")) |
|
|
|
return choices |
|
} |
|
|
|
func selectCommitMessage(commitMessages string) string { |
|
choicesArr := strings.Split(commitMessages, "\n") |
|
var qs = []*survey.Question{ |
|
{ |
|
Name: "commitMessage", |
|
Prompt: &survey.Select{ |
|
Message: "Please select a commit message:", |
|
Options: choicesArr, |
|
}, |
|
}, |
|
{ |
|
Name: "tweakMessage", |
|
Prompt: &survey.Confirm{ |
|
Message: "Do you want to change the commit message before committing?", |
|
}, |
|
}, |
|
} |
|
answers := struct { |
|
CommitMessage string |
|
TweakMessage bool |
|
}{} |
|
err := survey.Ask(qs, &answers) |
|
if err != nil { |
|
fmt.Println(err) |
|
os.Exit(1) |
|
} |
|
commitMessage := answers.CommitMessage |
|
if answers.TweakMessage { |
|
prompt := &survey.Input{ |
|
Message: "Enter the commit message:", |
|
Default: commitMessage, |
|
} |
|
err = survey.AskOne(prompt, &commitMessage) |
|
if err != nil { |
|
fmt.Println(err) |
|
os.Exit(1) |
|
} |
|
} |
|
return commitMessage |
|
} |
|
|
|
// commitStagedFiles runs `git commit -m <commitMessage>` in the current
// directory, first printing the command (quoted with Go's %q, which only
// approximates shell quoting).
//
// git's stdin, stdout and stderr are connected to the terminal. Commit output,
// hook output and git's own error messages (for example from a failing
// pre-commit hook) therefore reach the user instead of being captured and
// discarded. If the commit fails, the error is printed and the process exits
// with status 1.
func commitStagedFiles(commitMessage string) {
	fmt.Printf("Executing: git commit -m %q\n", commitMessage)

	cmd := exec.Command("git", "commit", "-m", commitMessage)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
}
|
|
|
func main() { |
|
// throw an error if the OPENAI_API_KEY is not set |
|
apiKey := os.Getenv("OPENAI_API_KEY") |
|
if apiKey == "" { |
|
fmt.Println("Error: OPENAI_API_KEY is not set") |
|
os.Exit(1) |
|
} |
|
|
|
// Create an OpenAI API client |
|
client := gpt3.NewClient(apiKey, gpt3.WithDefaultEngine(gpt3.TextDavinci003Engine)) |
|
|
|
// Check if command is executed in git repo |
|
if !isGitRepo() { |
|
fmt.Println("Error: Command must be executed in a git repository") |
|
os.Exit(1) |
|
} |
|
|
|
// Get the staged files |
|
stagedFiles := getStagedFiles() |
|
if len(stagedFiles) == 0 { |
|
fmt.Println("Error: No staged files, nothing to commit") |
|
os.Exit(1) |
|
} |
|
|
|
// Get the diff for the staged files |
|
diff := getDiff(stagedFiles) |
|
|
|
// Throw an error if there are more than allowed lines in the diff |
|
ok, maxLines := checkDiffLinesLength(diff) |
|
if !ok { |
|
fmt.Printf("Error: Diff lines are more than %d\n", maxLines) |
|
os.Exit(1) |
|
} |
|
|
|
// Use GPT-3 to generate 5 commit messages based on the diff |
|
commitMessages := generateCommitMessages(diff, client) |
|
|
|
// Prompt the user to select a commit message |
|
commitMessage := selectCommitMessage(commitMessages) |
|
|
|
// Print the selected commit message |
|
fmt.Printf("Commit message: %s\n", commitMessage) |
|
|
|
// Use the selected commit message to commit the staged files |
|
commitStagedFiles(commitMessage) |
|
} |
|
|