Skip to content

Instantly share code, notes, and snippets.

@Fishwaldo
Created March 3, 2024 09:21
Show Gist options
  • Save Fishwaldo/b1f0736b3e9e2faeb947f515ee977406 to your computer and use it in GitHub Desktop.
Save Fishwaldo/b1f0736b3e9e2faeb947f515ee977406 to your computer and use it in GitHub Desktop.
Generate Go enums from `//go:enum` directives in source-file comments.
package main
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/constant"
	"go/format"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"path/filepath"
	"strings"
	"text/template"
	"unicode"
)
// Enum is one constant taken from a const block annotated with a
// "//go:enum <Type>" directive.
type Enum struct {
	Name  string // identifier of the constant in the source file
	Value string // the constant's value as a decimal integer literal
	Type  string // enum type name given by the directive
}

// main scans the Go files in the current directory for top-level const blocks
// whose doc comment contains a "//go:enum <Type>" directive and writes one
// "<Type>_enum.go" file per enum type. On error it prints the problem to
// stderr and exits with status 1.
func main() {
	if err := generateEnums("."); err != nil {
		fmt.Fprintln(os.Stderr, "go:enum:", err)
		os.Exit(1)
	}
}

// enumDirective marks a const block whose constants become enum values; it
// must be followed by exactly one type name, e.g. "//go:enum color".
const enumDirective = "//go:enum"

// enumSet holds every constant collected for one enum type together with the
// package name the generated file has to declare.
type enumSet struct {
	pkg   string
	enums []Enum
}

// generateEnums parses the non-test Go files in dir, groups the annotated
// constants by enum type (in order of first appearance), and writes one
// generated file per type into dir.
func generateEnums(dir string) error {
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		return err
	}
	fset := token.NewFileSet()
	sets := make(map[string]*enumSet)
	var order []string
	for _, entry := range entries {
		name := entry.Name()
		// Test files are skipped: a const block in an external test package
		// (package foo_test) would otherwise yield a non-test file in that
		// package, which breaks the build.
		if entry.IsDir() || !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
			continue
		}
		path := filepath.Join(dir, name)
		file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
		if err != nil {
			return err
		}
		enums, err := collectEnums(fset, file)
		if err != nil {
			return err
		}
		for _, e := range enums {
			set, ok := sets[e.Type]
			if !ok {
				set = &enumSet{pkg: file.Name.Name}
				sets[e.Type] = set
				order = append(order, e.Type)
			} else if set.pkg != file.Name.Name {
				return fmt.Errorf("%s: enum %s is declared in both package %s and package %s", path, e.Type, set.pkg, file.Name.Name)
			}
			set.enums = append(set.enums, e)
		}
	}
	// Each type is rendered exactly once, after all of its constants are known.
	for _, typ := range order {
		if err := writeEnumFile(dir, typ, sets[typ]); err != nil {
			return err
		}
	}
	return nil
}

// collectEnums returns, in declaration order, the constants of every
// top-level const block in file whose doc comment carries a //go:enum
// directive. Blank (_) constants are skipped.
func collectEnums(fset *token.FileSet, file *ast.File) ([]Enum, error) {
	var enums []Enum
	for _, d := range file.Decls {
		decl, ok := d.(*ast.GenDecl)
		if !ok || decl.Tok != token.CONST {
			continue
		}
		typ, err := parseEnumDirective(fset, decl.Doc)
		if err != nil {
			return nil, err
		}
		if typ == "" {
			continue
		}
		// A ConstSpec without values repeats the previous expression list,
		// re-evaluated with the current iota (the spec's index in the block).
		var exprs []ast.Expr
		for index, spec := range decl.Specs {
			vs := spec.(*ast.ValueSpec) // const declarations only hold ValueSpecs
			if len(vs.Values) > 0 {
				exprs = vs.Values
			}
			if len(exprs) != len(vs.Names) {
				return nil, fmt.Errorf("%s: constant %s: expected %d value(s), found %d", fset.Position(vs.Pos()), vs.Names[0].Name, len(vs.Names), len(exprs))
			}
			for i, name := range vs.Names {
				if name.Name == "_" {
					continue // blank constants only advance iota
				}
				value, err := evalIntConst(exprs[i], index)
				if err != nil {
					return nil, fmt.Errorf("%s: constant %s: %w", fset.Position(name.Pos()), name.Name, err)
				}
				enums = append(enums, Enum{Name: name.Name, Value: value, Type: typ})
			}
		}
	}
	return enums, nil
}

// parseEnumDirective returns the type named by the first //go:enum directive
// in doc, or "" when doc (which may be nil) has none. The type must be a
// single identifier starting with a cased letter whose lower-cased form is not
// a keyword, because both its title-cased and lower-cased spellings are
// declared in the generated file.
func parseEnumDirective(fset *token.FileSet, doc *ast.CommentGroup) (string, error) {
	if doc == nil {
		return "", nil
	}
	for _, c := range doc.List {
		text := strings.TrimSpace(c.Text)
		if !strings.HasPrefix(text, enumDirective) {
			continue
		}
		rest := text[len(enumDirective):]
		if rest != "" && !unicode.IsSpace(rune(rest[0])) {
			continue // a different word, e.g. "//go:enumerate"
		}
		args := strings.Fields(rest)
		if len(args) != 1 {
			return "", fmt.Errorf("%s: %s needs exactly one type name, got %d", fset.Position(c.Pos()), enumDirective, len(args))
		}
		typ := args[0]
		lower := strings.ToLower(typ)
		if !token.IsIdentifier(typ) || !token.IsIdentifier(lower) || titleCase(typ) == lower {
			return "", fmt.Errorf("%s: invalid enum type name %q", fset.Position(c.Pos()), typ)
		}
		return typ, nil
	}
	return "", nil
}

// evalIntConst evaluates expr as an integer constant expression with iota
// bound to iotaValue and returns its decimal form. The value must fit in an
// int64.
func evalIntConst(expr ast.Expr, iotaValue int) (string, error) {
	v, err := evalConst(expr, iotaValue)
	if err != nil {
		return "", err
	}
	if _, exact := constant.Int64Val(v); !exact {
		return "", fmt.Errorf("value %s overflows int64", v.ExactString())
	}
	return v.ExactString(), nil
}

// evalConst evaluates expr exactly with go/constant. It supports integer
// literals, iota, parentheses, unary + - ^ and the integer binary operators.
// Every leaf is an integer, so every go/constant operation used here receives
// integer operands.
func evalConst(expr ast.Expr, iotaValue int) (constant.Value, error) {
	switch e := expr.(type) {
	case *ast.BasicLit:
		if e.Kind != token.INT {
			return nil, fmt.Errorf("unsupported %s literal %s (enum values must be integers)", e.Kind, e.Value)
		}
		v := constant.MakeFromLiteral(e.Value, e.Kind, 0)
		if v.Kind() != constant.Int {
			return nil, fmt.Errorf("malformed integer literal %s", e.Value)
		}
		return v, nil
	case *ast.Ident:
		if e.Name != "iota" {
			return nil, fmt.Errorf("unsupported identifier %s (only iota is allowed)", e.Name)
		}
		return constant.MakeInt64(int64(iotaValue)), nil
	case *ast.ParenExpr:
		return evalConst(e.X, iotaValue)
	case *ast.UnaryExpr:
		if e.Op != token.ADD && e.Op != token.SUB && e.Op != token.XOR {
			break
		}
		x, err := evalConst(e.X, iotaValue)
		if err != nil {
			return nil, err
		}
		return constant.UnaryOp(e.Op, x, 0), nil
	case *ast.BinaryExpr:
		x, err := evalConst(e.X, iotaValue)
		if err != nil {
			return nil, err
		}
		y, err := evalConst(e.Y, iotaValue)
		if err != nil {
			return nil, err
		}
		switch e.Op {
		case token.ADD, token.SUB, token.MUL, token.AND, token.OR, token.XOR, token.AND_NOT:
			return constant.BinaryOp(x, e.Op, y), nil
		case token.QUO, token.REM:
			if constant.Sign(y) == 0 {
				return nil, fmt.Errorf("division by zero")
			}
			op := e.Op
			if op == token.QUO {
				op = token.QUO_ASSIGN // go/constant's spelling of integer division
			}
			return constant.BinaryOp(x, op, y), nil
		case token.SHL, token.SHR:
			s, ok := constant.Uint64Val(y)
			if !ok || s >= 64 {
				return nil, fmt.Errorf("invalid shift count %s", y.ExactString())
			}
			return constant.Shift(x, e.Op, uint(s)), nil
		}
	}
	return nil, fmt.Errorf("unsupported constant expression of type %T", expr)
}

// writeEnumFile renders, gofmt-formats, and writes "<typ>_enum.go" into dir.
// It rejects constants whose lower-cased names clash, because the lower-cased
// name is both the enum's string form and a case label in the generated
// UnmarshalText.
func writeEnumFile(dir, typ string, set *enumSet) error {
	seen := make(map[string]string, len(set.enums))
	for _, e := range set.enums {
		key := strings.ToLower(e.Name)
		if other, dup := seen[key]; dup {
			return fmt.Errorf("enum %s: constants %s and %s have the same lower-case name %q", typ, other, e.Name, key)
		}
		seen[key] = e.Name
	}
	data := struct {
		PackageName string
		Type        string
		Enums       []Enum
	}{set.pkg, typ, set.enums}
	var buf bytes.Buffer
	if err := enumTemplate.Execute(&buf, data); err != nil {
		return fmt.Errorf("rendering enum %s: %w", typ, err)
	}
	// format.Source both gofmt-formats the output and rejects it if the
	// template produced syntactically invalid Go.
	src, err := format.Source(buf.Bytes())
	if err != nil {
		return fmt.Errorf("formatting enum %s: %w", typ, err)
	}
	// 0666 (before umask) matches the permissions os.Create would use.
	return ioutil.WriteFile(filepath.Join(dir, typ+"_enum.go"), src, 0o666)
}

// titleCase upper-cases (title-cases) the first rune of s and leaves the rest
// unchanged. It replaces the deprecated strings.Title, which behaves the same
// way on a single identifier.
func titleCase(s string) string {
	if s == "" {
		return s
	}
	runes := []rune(s)
	runes[0] = unicode.ToTitle(runes[0])
	return string(runes)
}

// enumTemplate renders one enum type. The helper funcs must be registered
// before Parse; otherwise parsing fails with "function not defined".
var enumTemplate = template.Must(template.New("enum").Funcs(template.FuncMap{
	"lower": strings.ToLower,
	"title": titleCase,
}).Parse(enumTemplateText))

// enumTemplateText is the generated file. The exported interface lists only
// value-receiver methods. The package-level values hold the unexported struct
// by value, and a value's method set excludes the pointer-receiver
// UnmarshalJSON/UnmarshalText. Parse<Type> is the exported way to turn a name
// back into a value.
//
// NOTE(review): each generated variable reuses the constant's identifier. The
// annotated const block therefore must not be compiled into the same package
// (for example, keep it in a file marked "//go:build ignore"); otherwise the
// names collide.
const enumTemplateText = `// Code generated by go:enum tool; DO NOT EDIT.

package {{ .PackageName }}

import (
	"encoding/json"
	"errors"
	"strings"
)

// {{ .Type | title }} is implemented only by the unexported {{ .Type | lower }} type; its values
// are the variables declared at the end of this file.
type {{ .Type | title }} interface {
	String() string
	MarshalJSON() ([]byte, error)
	MarshalText() ([]byte, error)
	Is{{ .Type | title }}()
}

type {{ .Type | lower }} struct {
	name  string
	value int
}

func (r {{ .Type | lower }}) String() string {
	return r.name
}

func (r {{ .Type | lower }}) MarshalJSON() ([]byte, error) {
	return json.Marshal(r.name)
}

func (r {{ .Type | lower }}) MarshalText() (text []byte, err error) {
	return []byte(r.name), nil
}

func (r *{{ .Type | lower }}) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	return r.UnmarshalText([]byte(s))
}

func (r *{{ .Type | lower }}) UnmarshalText(text []byte) error {
	switch strings.ToLower(string(text)) {
{{- range .Enums }}
	case "{{ .Name | lower }}":
		*r = {{ .Name }}.({{ $.Type | lower }})
{{- end }}
	default:
		return errors.New("invalid {{ .Type }} value: " + string(text))
	}
	return nil
}

func (r {{ .Type | lower }}) Is{{ .Type | title }}() {}

// Parse{{ .Type | title }} returns the {{ .Type | title }} whose name matches s, ignoring case.
func Parse{{ .Type | title }}(s string) ({{ .Type | title }}, error) {
	var v {{ .Type | lower }}
	if err := v.UnmarshalText([]byte(s)); err != nil {
		return nil, err
	}
	return v, nil
}

var (
{{- range .Enums }}
	{{ .Name }} {{ $.Type | title }} = {{ $.Type | lower }}{"{{ .Name | lower }}", {{ .Value }}}
{{- end }}
)
`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment