Skip to content

Instantly share code, notes, and snippets.

@mmcloughlin
Last active January 5, 2019 22:53
Show Gist options
  • Save mmcloughlin/2f6ff496978ef57efce13eab84c69cd5 to your computer and use it in GitHub Desktop.
Save mmcloughlin/2f6ff496978ef57efce13eab84c69cd5 to your computer and use it in GitHub Desktop.
Generate global versions of struct methods
package main
import (
"bytes"
"flag"
"fmt"
"go/format"
"go/types"
"log"
"os"
"strings"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/go/packages"
)
// importpath holds the -pkg flag: the package pattern handed to packages.Load.
var importpath = flag.String("pkg", "", "package to load")

// name holds the -var flag: the package-level variable whose methods are wrapped.
var name = flag.String("var", "", "name of the global variable")
// LookupType searches the package-level scope of each loaded package, in the
// order given, and returns the first object called name, or nil when none of
// the packages declares it. Despite its name, the returned object may be of
// any kind (variable, type name, function or constant).
func LookupType(pkgs []*packages.Package, name string) types.Object {
	for i := range pkgs {
		if found := pkgs[i].Types.Scope().Lookup(name); found != nil {
			return found
		}
	}
	return nil
}
// ExportedMethods returns the exported methods in the intuitive method set of
// t, as computed by typeutil.IntuitiveMethodSet. Entries whose object is not a
// *types.Func are skipped. The result preserves the method set's order and is
// nil when nothing qualifies.
func ExportedMethods(t types.Type) []*types.Func {
	var exported []*types.Func
	for _, sel := range typeutil.IntuitiveMethodSet(t, nil) {
		if fn, isFunc := sel.Obj().(*types.Func); isFunc && fn.Exported() {
			exported = append(exported, fn)
		}
	}
	return exported
}
// ImportedQualifier is a types.Qualifier that refers to every package by its
// short name (the "token" in "token.Pos"). It makes no exception for any
// package, so types are always printed package-qualified.
func ImportedQualifier(pkg *types.Package) string {
	short := pkg.Name()
	return short
}
// WriteFunction appends to buf a package-level function that forwards to the
// method f invoked on the global variable varname, for example:
//
//	// Printf calls Printf on the global context.
//	func Printf(format string, args ...interface{}) { ctx.Printf(format, args...) }
//
// The wrapper has f's name and f's parameter and result types. Types are
// printed with ImportedQualifier.
//
// Parameter names that cannot be forwarded are replaced with fresh names of
// the form pN. This covers unnamed parameters, blank ("_") parameters, and
// parameters called varname, which would shadow the global inside the body.
// A variadic final parameter is forwarded with "..." so its elements are
// passed individually rather than as one slice. If any named result is called
// varname, all result names are dropped for the same shadowing reason.
//
// NOTE(review): ImportedQualifier qualifies every package, including the one
// that declares varname. If the output is placed in that package, any of its
// own types in a signature will carry a self-qualifier that must be removed.
func WriteFunction(buf *bytes.Buffer, f *types.Func, varname string) {
	sig := f.Type().(*types.Signature)
	params := forwardableParams(sig.Params(), sig.Results(), varname)
	results := sig.Results()
	if tupleHasName(results, varname) {
		results = unnamedTuple(results)
	}
	// NewSignature (rather than the Go 1.18+ NewSignatureType) keeps the
	// generator buildable with older toolchains.
	wrapper := types.NewSignature(nil, params, results, sig.Variadic())

	fmt.Fprintf(buf, "// %s calls %s on the global context.\n", f.Name(), f.Name())
	fmt.Fprintf(buf, "func %s", f.Name())
	types.WriteSignature(buf, wrapper, ImportedQualifier)
	fmt.Fprint(buf, " { ")
	if results.Len() > 0 {
		fmt.Fprint(buf, "return ")
	}
	args := make([]string, params.Len())
	for i := range args {
		args[i] = params.At(i).Name()
	}
	if sig.Variadic() {
		// A variadic signature always has at least one parameter.
		args[len(args)-1] += "..."
	}
	fmt.Fprintf(buf, "%s.%s(%s)", varname, f.Name(), strings.Join(args, ", "))
	fmt.Fprint(buf, " }\n\n")
}

// forwardableParams returns params with a fresh name for every parameter the
// wrapper body could not refer to (unnamed, blank, or called varname). The
// fresh name is pN, where N is the parameter's position, with underscores
// appended until it clashes with no parameter, result or varname. Usable
// parameters are returned unchanged.
func forwardableParams(params, results *types.Tuple, varname string) *types.Tuple {
	taken := map[string]bool{varname: true}
	for _, tup := range []*types.Tuple{params, results} {
		for i := 0; i < tup.Len(); i++ {
			taken[tup.At(i).Name()] = true
		}
	}
	vars := make([]*types.Var, params.Len())
	for i := range vars {
		p := params.At(i)
		if n := p.Name(); n != "" && n != "_" && n != varname {
			vars[i] = p
			continue
		}
		fresh := fmt.Sprintf("p%d", i)
		for taken[fresh] {
			fresh += "_"
		}
		taken[fresh] = true
		vars[i] = types.NewParam(p.Pos(), p.Pkg(), fresh, p.Type())
	}
	return types.NewTuple(vars...)
}

// tupleHasName reports whether any variable in t is called name.
func tupleHasName(t *types.Tuple, name string) bool {
	for i := 0; i < t.Len(); i++ {
		if t.At(i).Name() == name {
			return true
		}
	}
	return false
}

// unnamedTuple returns a copy of t with every variable's name removed.
func unnamedTuple(t *types.Tuple) *types.Tuple {
	vars := make([]*types.Var, t.Len())
	for i := range vars {
		v := t.At(i)
		vars[i] = types.NewParam(v.Pos(), v.Pkg(), "", v.Type())
	}
	return types.NewTuple(vars...)
}
func main() {
flag.Parse()
// Load the package.
cfg := &packages.Config{
Mode: packages.LoadTypes,
}
pkgs, err := packages.Load(cfg, *importpath)
if err != nil {
log.Fatal(err)
}
// Lookup the type.
obj := LookupType(pkgs, *name)
if obj == nil {
log.Fatal("could not find type")
}
fs := ExportedMethods(obj.Type())
if len(fs) == 0 {
log.Fatal("no exported methods found")
}
buf := bytes.NewBuffer(nil)
for _, f := range fs {
WriteFunction(buf, f, *name)
}
src, err := format.Source(buf.Bytes())
if err != nil {
log.Fatal(err)
}
os.Stdout.Write(src)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment