Skip to content

Instantly share code, notes, and snippets.

@vbatts
Last active October 24, 2017 23:55
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vbatts/4723845a25f50ba0dbadd5f6f5a69fe7 to your computer and use it in GitHub Desktop.
Save vbatts/4723845a25f50ba0dbadd5f6f5a69fe7 to your computer and use it in GitHub Desktop.
[golang] Looking for struct pointers in interface function signatures
package main
// spurred from discussion around https://github.com/kubernetes/kubernetes/pull/54257#issuecomment-338274869
import (
"flag"
"fmt"
"go/ast"
"go/build"
"go/parser"
"go/token"
"io/ioutil"
"log"
"path/filepath"
"strings"
"github.com/davecgh/go-spew/spew"
"github.com/hashicorp/errwrap"
multierror "github.com/hashicorp/go-multierror"
)
// main parses the command-line flags and treats each remaining argument as a
// package import path to check. The first package that yields an error stops
// the program through log.Fatal.
func main() {
	flag.Parse()
	pkgs := flag.Args()
	for i := range pkgs {
		err := findInterfacePointer(pkgs[i])
		if err == nil {
			continue
		}
		log.Fatal(err)
	}
}
// flDebug is the -D command-line flag (default false). When set,
// checkForPointer dumps the offending identifier's name and resolved object
// with spew before returning its error.
var (
	flDebug = flag.Bool("D", false, "enable debug output")
)
// ErrInterfacePointer reports an interface method parameter whose type is a
// pointer to a named type.
type ErrInterfacePointer struct {
	Name string    // identifier of the pointed-to type
	Pos  token.Pos // position of that identifier in the parsing FileSet
}

// Error implements the error interface. Pos is printed as the raw token.Pos
// integer, not as a file:line pair.
func (e ErrInterfacePointer) Error() string {
	const msg = "do not use pointers to structs in interface functions"
	return fmt.Sprintf("%s: %s (%d)", msg, e.Name, e.Pos)
}
// findInterfacePointer loads the package named by pkgname with the default
// go/build context. It scans the Go files the package lists (GoFiles,
// CgoFiles and in-package TestGoFiles) for top-level interface types whose
// methods take a pointer parameter flagged by checkForPointer.
//
// An error from build.Context.Import is returned directly. Otherwise every
// problem found (unreadable files, parse errors and offending methods) is
// accumulated with multierror.Append and returned together. The result is nil
// when nothing was found.
func findInterfacePointer(pkgname string) error {
	ctxt := build.Default
	pkg, err := ctxt.Import(pkgname, "", 0)
	if err != nil {
		return err
	}

	var names []string
	names = append(names, pkg.GoFiles...)
	names = append(names, pkg.CgoFiles...)
	names = append(names, pkg.TestGoFiles...) // _test.go files in the same package
	// pkg.SFiles (assembly) is deliberately not collected: only Go sources
	// can be parsed.

	var result error
	fset := token.NewFileSet()
	for _, name := range names {
		if !strings.HasSuffix(name, ".go") {
			continue
		}
		// pkg.Dir is the directory holding the sources. Joining SrcRoot and
		// ImportPath only matches it in GOPATH layouts; SrcRoot is documented
		// as "" when unknown, which would yield a wrong relative path.
		data, err := ioutil.ReadFile(filepath.Join(pkg.Dir, name))
		if err != nil {
			result = multierror.Append(result, errwrap.Wrapf(fmt.Sprintf("[%s:%s] {{err}}", pkgname, name), err))
			// Nothing was read; parsing empty data would only add a
			// misleading parse error.
			continue
		}
		file, err := parser.ParseFile(fset, name, data, 0)
		if err != nil {
			// On syntax errors ParseFile still returns a partial AST, so
			// keep scanning whatever could be parsed.
			result = multierror.Append(result, errwrap.Wrapf(fmt.Sprintf("[%s:%s] {{err}}", pkgname, name), err))
		}
		if file == nil {
			continue
		}
		if err := checkInterfaceMethods(fset, pkgname, name, file); err != nil {
			// multierror.Append flattens the nested *multierror.Error.
			result = multierror.Append(result, err)
		}
	}
	return result
}

// checkInterfaceMethods runs checkForPointer on every method of every
// interface type declared at the top level of file. Each finding is wrapped
// with a "[pkgname/name:line]" prefix, or "[pkgname/name]" when the position
// cannot be mapped back to a file in fset. It returns nil when nothing is found.
func checkInterfaceMethods(fset *token.FileSet, pkgname, name string, file *ast.File) error {
	var result error
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range gen.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			iface, ok := ts.Type.(*ast.InterfaceType)
			if !ok || iface.Methods == nil {
				continue
			}
			for _, method := range iface.Methods.List {
				err := checkForPointer(method)
				if err == nil {
					continue
				}
				format := fmt.Sprintf("[%s/%s] {{err}}", pkgname, name)
				if f := fset.File(method.Pos()); f != nil {
					format = fmt.Sprintf("[%s/%s:%d] {{err}}", pkgname, name, f.Line(method.Pos()))
				}
				result = multierror.Append(result, errwrap.Wrapf(format, err))
			}
		}
	}
	return result
}
// checkForPointer inspects one entry of an interface's method list. It
// returns an ErrInterfacePointer for the first parameter whose type is a
// pointer to a struct type declared in the same file, e.g. Welcome(*Person).
// It returns nil for non-method entries (embedded interfaces) and when no
// such parameter exists.
//
// Only identifiers resolved by the parser are considered. parser.ParseFile
// resolves per file, so the following are not reported:
//   - pointers to predeclared types (*int)
//   - pointers to types declared in other files of the package
//   - pointers to qualified types (*pkg.T)
//   - pointers to named non-struct types (*MyInt, a pointer to an interface)
//
// With -D set, the offending identifier's name and object are dumped with
// spew before returning.
func checkForPointer(m *ast.Field) error {
	fn, ok := m.Type.(*ast.FuncType)
	if !ok || fn.Params == nil {
		return nil
	}
	for _, field := range fn.Params.List {
		star, ok := field.Type.(*ast.StarExpr)
		if !ok {
			// the parameter is not a pointer
			continue
		}
		ident, ok := star.X.(*ast.Ident)
		if !ok {
			// e.g. *pkg.T or *[]T
			continue
		}
		if !isLocalStructType(ident) {
			// unresolved (predeclared or declared elsewhere) or not a struct
			continue
		}
		if *flDebug {
			spew.Dump(ident.Name)
			spew.Dump(ident.Obj)
		}
		return ErrInterfacePointer{ident.Name, ident.NamePos}
	}
	return nil
}

// isLocalStructType reports whether ident was resolved by the parser to a
// type declaration of the form `type T struct{...}`. Unresolved identifiers,
// aliases, types defined from another named type, and non-struct types
// yield false.
func isLocalStructType(ident *ast.Ident) bool {
	obj := ident.Obj
	if obj == nil || obj.Kind != ast.Typ {
		return false
	}
	spec, ok := obj.Decl.(*ast.TypeSpec)
	if !ok {
		return false
	}
	_, isStruct := spec.Type.(*ast.StructType)
	return isStruct
}
// testBadCode is a sample Go source file containing interface methods that
// take struct pointers. In Personer, Welcome and Cheers both take *Person (a
// struct declared in the same snippet), Cheers also takes *int, and Hello and
// Hi take no pointer parameters.
// NOTE(review): not referenced anywhere in this file; presumably intended as a
// fixture for parser.ParseFile-based tests of checkForPointer. Confirm.
var testBadCode = `
package pointertest
type Person struct {
Name string
Age int
Address Address
}
type Address struct {
Street string
ZipCode string
}
type Personer interface {
Hello()
Welcome(*Person) error
Hi(name string)
Cheers(string, *Person, *int, int64)
}
`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment