Skip to content

Instantly share code, notes, and snippets.

@chrisseto
Created September 20, 2023 19:49
Show Gist options
  • Save chrisseto/cd5f94c7e70cbbccd9df05788e4b1cb8 to your computer and use it in GitHub Desktop.
Save chrisseto/cd5f94c7e70cbbccd9df05788e4b1cb8 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"fmt"
"go/ast"
"go/token"
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/packages"
)
// MethodByName returns the method called name among the methods explicitly
// declared on t (value and pointer receivers alike), or nil when t declares no
// such method. Methods promoted from embedded fields are not searched.
func MethodByName(t *types.Named, name string) *types.Func {
	for i, n := 0, t.NumMethods(); i < n; i++ {
		if m := t.Method(i); m.Name() == name {
			return m
		}
	}
	return nil
}
// Filter returns, in their original order, the elements of items for which
// every predicate in filters returns true. Predicates run in the order given
// and stop at the first one that rejects an element; with no predicates every
// element is kept. The result is a freshly allocated, non-nil slice (capacity
// len(items)); items itself is not modified.
func Filter[T any](items []T, filters ...func(T) bool) []T {
	kept := make([]T, 0, len(items))
next:
	for _, item := range items {
		for _, accept := range filters {
			if !accept(item) {
				continue next
			}
		}
		kept = append(kept, item)
	}
	return kept
}
// Objects returns every object declared directly in scope, ordered by name
// (the sorted order produced by Scope.Names).
func Objects(scope *types.Scope) []types.Object {
	// Scope.Names allocates and sorts a new slice on every call, so fetch the
	// names once and reuse them for both sizing and iteration.
	names := scope.Names()
	out := make([]types.Object, len(names))
	for i, name := range names {
		out[i] = scope.Lookup(name)
	}
	return out
}
// IsExported reports whether obj's name is exported, i.e. begins with an
// upper-case letter.
func IsExported(obj types.Object) bool {
	return token.IsExported(obj.Name())
}
// IsStruct reports whether the underlying type of obj's type is a struct.
// This is true for struct type declarations and also for, e.g., variables
// whose type is a struct.
func IsStruct(obj types.Object) bool {
	switch obj.Type().Underlying().(type) {
	case *types.Struct:
		return true
	default:
		return false
	}
}
// Implements returns a predicate reporting whether an object's type T, or the
// pointer type *T, implements iface. Checking *T as well picks up types whose
// methods are declared with pointer receivers.
func Implements(iface *types.Interface) func(types.Object) bool {
	return func(obj types.Object) bool {
		typ := obj.Type()
		if types.Implements(typ, iface) {
			return true
		}
		return types.Implements(types.NewPointer(typ), iface)
	}
}
// Contains reports whether pos lies within the source extent of n.
//
// ast.Node.End reports the position of the first character immediately
// *after* the node, so the extent is the half-open range [n.Pos(), n.End()).
// A pos equal to n.End() belongs to whatever follows n, not to n itself.
func Contains(n ast.Node, pos token.Pos) bool {
	return n.Pos() <= pos && pos < n.End()
}
// ToAst returns the first node of type T, found by a pre-order walk of
// pkg.Syntax, whose source range contains pos (as judged by Contains). The
// walk stops descending as soon as it finds a match, so this is the outermost
// node of type T on the path to pos, not the innermost one.
//
// If no node of type T contains pos, the zero value of T is returned (nil for
// pointer types such as *ast.FuncDecl); callers must check for it.
func ToAst[T ast.Node](pkg *packages.Package, pos token.Pos) T {
	var (
		match T
		done  bool
	)
	visit := func(n ast.Node) bool {
		// Prune once a match has been recorded, on the nil "leaving node"
		// callback from ast.Inspect, and on any subtree that cannot contain
		// pos. The nil check must run before Contains is called on n.
		switch {
		case done, n == nil, !Contains(n, pos):
			return false
		}
		candidate, ok := n.(T)
		if !ok {
			return true // not the wanted type; keep descending toward pos
		}
		match, done = candidate, true
		return false
	}
	for _, file := range pkg.Syntax {
		if done {
			break
		}
		ast.Inspect(file, visit)
	}
	return match
}
// IsDDL returns a predicate that does its best to determine whether an
// implementor of Statement returns the DDL constant from its
// StatementReturnType method.
//
// The check is a syntactic heuristic: it locates the method's declaration in
// pkg's syntax trees and reports true if any identifier named "DDL" appears in
// its body. It returns false, rather than panicking, when the method cannot be
// found or its declaration (with a body) is not part of pkg's syntax. That
// happens, for example, when the method is promoted from an embedded type
// declared in another package.
func IsDDL(pkg *packages.Package) func(types.Object) bool {
	return func(obj types.Object) bool {
		// LookupFieldOrMethod (unlike a scan of a *types.Named's own methods)
		// also finds methods promoted through embedded fields, and with
		// addressable=true it includes pointer-receiver methods. It also works
		// when obj's type is not a *types.Named.
		found, _, _ := types.LookupFieldOrMethod(obj.Type(), true, obj.Pkg(), "StatementReturnType")
		method, ok := found.(*types.Func)
		if !ok {
			return false
		}
		fn := ToAst[*ast.FuncDecl](pkg, method.Pos())
		if fn == nil || fn.Body == nil {
			return false
		}
		isDDL := false
		ast.Inspect(fn.Body, func(n ast.Node) bool {
			if ident, ok := n.(*ast.Ident); ok && ident.Name == "DDL" {
				isDDL = true
				return false
			}
			// Stop descending anywhere once a match has been seen.
			return !isDDL
		})
		return isDDL
	}
}
// main loads the ./pkg/sql/sem/tree package (relative to the working
// directory). It prints, sorted and one per line, the names of its DDL
// statement structs and of structs that look like DDL subcommands, with the
// first byte lower-cased.
func main() {
	const pattern = "./pkg/sql/sem/tree"
	pkgs, err := packages.Load(&packages.Config{
		Mode: packages.NeedFiles | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
	}, pattern)
	if err != nil {
		panic(err)
	}
	// packages.Load reports only driver-level failures through err. Parse and
	// type errors are attached to the individual packages. Without this check
	// a broken load surfaces later as an opaque nil-pointer panic.
	if n := packages.PrintErrors(pkgs); n > 0 {
		panic(fmt.Sprintf("loading %s: %d error(s)", pattern, n))
	}
	if len(pkgs) != 1 {
		panic(fmt.Sprintf("loading %s: expected 1 package, got %d", pattern, len(pkgs)))
	}
	pkg := pkgs[0]

	statementIface := mustLookupInterface(pkg, "Statement")
	nodeFormatterIface := mustLookupInterface(pkg, "NodeFormatter")

	// Find DDL statements: all exported structs that implement Statement and
	// return DDL from StatementReturnType.
	objs := Objects(pkg.Types.Scope())
	publicStructs := Filter(objs, IsExported, IsStruct)
	statements := Filter(publicStructs, Implements(statementIface))
	ddl := Filter(statements, IsDDL(pkg))

	// Find DDL subcommands, like AlterTableAddColumn: exported structs that
	// implement NodeFormatter and whose name starts with (but is not equal
	// to) a DDL statement's name.
	// This is pretty loose and might find false positives or skip subcommands.
	ddlNames := make(map[string]struct{}, len(ddl))
	for _, stmt := range ddl {
		ddlNames[stmt.Name()] = struct{}{}
	}
	nodeFormatters := Filter(publicStructs, Implements(nodeFormatterIface))
	ddlSubCommands := Filter(nodeFormatters, func(obj types.Object) bool {
		if _, isStmt := ddlNames[obj.Name()]; isStmt {
			return false
		}
		for _, stmt := range ddl {
			if strings.HasPrefix(obj.Name(), stmt.Name()) {
				return true
			}
		}
		return false
	})

	// Knit together the final list. The output gets manual post-processing
	// anyhow, so don't worry about duplicates.
	ddl = append(ddl, ddlSubCommands...)
	sort.Slice(ddl, func(i, j int) bool {
		return ddl[i].Name() < ddl[j].Name()
	})
	for _, obj := range ddl {
		// Lower-case only the first byte. NOTE: this assumes the name starts
		// with an ASCII letter; a multi-byte first rune would be mangled.
		name := []byte(obj.Name())
		name[0] = bytes.ToLower(name[:1])[0]
		fmt.Printf("%s\n", name)
	}
}

// mustLookupInterface returns the interface underlying the package-level type
// called name in pkg. It panics with a descriptive message if no such object
// exists or its underlying type is not an interface.
func mustLookupInterface(pkg *packages.Package, name string) *types.Interface {
	obj := pkg.Types.Scope().Lookup(name)
	if obj == nil {
		panic(fmt.Sprintf("%s: no package-level object named %q", pkg.Types.Path(), name))
	}
	iface, ok := obj.Type().Underlying().(*types.Interface)
	if !ok {
		panic(fmt.Sprintf("%s: %q is %s, not an interface type", pkg.Types.Path(), name, obj.Type()))
	}
	return iface
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment