-
-
Save alsritter/de97e60119daf1a9aac933bef132ca30 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"errors" | |
"fmt" | |
errx "github.com/pkg/errors" | |
"go.uber.org/zap" | |
"go/ast" | |
"go/build" | |
"go/format" | |
"go/parser" | |
"go/token" | |
"io" | |
"log" | |
"os" | |
"strings" | |
) | |
// alias | |
type ( | |
Logger = zap.Logger | |
SugaredLogger = zap.SugaredLogger | |
) | |
func main() { | |
err := run() | |
if err != nil { | |
log.Fatalln(err) | |
} | |
} | |
func run() error { | |
dir, err := getImportPkg("go.uber.org/zap") | |
if err != nil { | |
return errx.WithStack(err) | |
} | |
log.Printf("dir: %+v", dir) | |
pkg, err := parseDir(dir, "zap") | |
if err != nil { | |
return errx.WithStack(err) | |
} | |
funcs, err := walkAst(pkg) | |
if err != nil { | |
return errx.WithStack(err) | |
} | |
err = writeGoFile(os.Stdout, funcs) | |
if err != nil { | |
return errx.WithStack(err) | |
} | |
return nil | |
} | |
// ================================================================== | |
func walkAst(node ast.Node) ([]ast.Decl, error) { | |
v := &visitor{} | |
ast.Walk(v, node) | |
log.Printf("funcs len: %d", len(v.funcs)) | |
var decls []ast.Decl | |
for _, v := range v.funcs { | |
decls = append(decls, v) | |
} | |
return decls, nil | |
} | |
// getImportPkg 通过 go/build 包获取其在 gomod 中的路径,不用手动填写: | |
func getImportPkg(pkg string) (string, error) { | |
p, err := build.Import(pkg, "", build.FindOnly) | |
if err != nil { | |
return "", err | |
} | |
return p.Dir, err | |
} | |
// parseDir 解析整个目录,找到对应包名的的 *ast.Package | |
func parseDir(dir, pkgName string) (*ast.Package, error) { | |
// 返回的是一个 package name -> package 的 Map | |
pkgMap, err := parser.ParseDir( | |
token.NewFileSet(), | |
dir, | |
func(info os.FileInfo) bool { | |
// skip go-test | |
return !strings.Contains(info.Name(), "_test.go") | |
}, | |
parser.Mode(0), // no comment | |
) | |
if err != nil { | |
return nil, errx.WithStack(err) | |
} | |
pkg, ok := pkgMap[pkgName] | |
if !ok { | |
err := errors.New("not found") | |
return nil, errx.WithStack(err) | |
} | |
return pkg, nil | |
} | |
// 定义一个 visitor 结构去实现 ast.Visitor 接口 | |
type visitor struct { | |
funcs []*ast.FuncDecl | |
} | |
// Visit 遍历 ast,找到 SugaredLogger 的所有 Exported 方法: | |
func (v *visitor) Visit(node ast.Node) ast.Visitor { | |
switch n := node.(type) { | |
// 只需处理 FuncDecl | |
case *ast.FuncDecl: | |
if n.Recv == nil || | |
!n.Name.IsExported() || // 判断方法名称是否为大写(公开 | |
len(n.Recv.List) != 1 { // 取得当前接收者下有多少个方法 | |
return nil | |
} | |
t, ok := n.Recv.List[0].Type.(*ast.StarExpr) | |
if !ok { | |
return nil | |
} | |
if t.X.(*ast.Ident).String() != "SugaredLogger" { | |
return nil | |
} | |
log.Printf("func name: %s", n.Name.String()) | |
v.funcs = append(v.funcs, rewriteFunc(n)) | |
} | |
return v | |
} | |
// rewriteFunc 修改函数的属性 | |
func rewriteFunc(fn *ast.FuncDecl) *ast.FuncDecl { | |
fn.Recv = nil // 将方法接收者置空,变为函数。 | |
fnName := fn.Name.String() | |
var args []string | |
// fn.Type 表示当前当前函数的属性(函数位置、函数的参数列表、函数的返回值) | |
for _, field := range fn.Type.Params.List { | |
// 因为 field 可以代表 struct 类型、interface 里面的 method 列表、或者一个参数签名,因此这里的 Names 是一个 List | |
for _, id := range field.Names { | |
// 取得参数名 | |
idStr := id.String() | |
_, ok := field.Type.(*ast.Ellipsis) // 判断当前 field 是否为可变参数 | |
if ok { | |
// Ellipsis args | |
idStr += "..." | |
} | |
args = append(args, idStr) | |
} | |
} | |
// 函数 body 改为调用 zap.S() 方法。 | |
exprStr := fmt.Sprintf(`zap02.S().%s(%s)`, fnName, strings.Join(args, ",")) | |
expr, err := parser.ParseExpr(exprStr) // 生成这个表达式(方法)的 AST 树 | |
if err != nil { | |
panic(err) | |
} | |
var body []ast.Stmt // 所有 AST Node 都实现了 Stmt | |
if fn.Type.Results != nil { | |
body = []ast.Stmt{ | |
// 如果有返回值则需要 return 语句。 | |
&ast.ReturnStmt{ | |
// Return: | |
Results: []ast.Expr{expr}, // 如果有返回值则需要 return 语句。 | |
}, | |
} | |
} else { | |
body = []ast.Stmt{ | |
&ast.ExprStmt{ | |
X: expr, | |
}, | |
} | |
} | |
fn.Body.List = body | |
return fn | |
} | |
// ast 转化为 go 代码 | |
func astToGo(dst *bytes.Buffer, node interface{}) error { | |
addNewline := func() { | |
err := dst.WriteByte('\n') // add newline | |
if err != nil { | |
log.Panicln(err) | |
} | |
} | |
addNewline() | |
// 单个 func 的 ast 转化为 go 代码,使用 go/format 包: | |
err := format.Node(dst, token.NewFileSet(), node) | |
if err != nil { | |
return err | |
} | |
addNewline() | |
return nil | |
} | |
// writeGoFile 拼装成完整 go file | |
func writeGoFile(wr io.Writer, funcs []ast.Decl) error { | |
// 输出Go代码 | |
header := `// Code generated by log-gen. DO NOT EDIT. | |
package log | |
import zap02 "go.uber.org/zap" | |
` | |
buffer := bytes.NewBufferString(header) | |
for _, fn := range funcs { | |
err := astToGo(buffer, fn) | |
if err != nil { | |
return errx.WithStack(err) | |
} | |
} | |
_, err := wr.Write(buffer.Bytes()) | |
return err | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment