Skip to content

Instantly share code, notes, and snippets.

@alsritter
Last active December 12, 2021 16:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alsritter/de97e60119daf1a9aac933bef132ca30 to your computer and use it in GitHub Desktop.
Save alsritter/de97e60119daf1a9aac933bef132ca30 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"errors"
"fmt"
errx "github.com/pkg/errors"
"go.uber.org/zap"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"io"
"log"
"os"
"strings"
)
// alias
type (
Logger = zap.Logger
SugaredLogger = zap.SugaredLogger
)
func main() {
err := run()
if err != nil {
log.Fatalln(err)
}
}
func run() error {
dir, err := getImportPkg("go.uber.org/zap")
if err != nil {
return errx.WithStack(err)
}
log.Printf("dir: %+v", dir)
pkg, err := parseDir(dir, "zap")
if err != nil {
return errx.WithStack(err)
}
funcs, err := walkAst(pkg)
if err != nil {
return errx.WithStack(err)
}
err = writeGoFile(os.Stdout, funcs)
if err != nil {
return errx.WithStack(err)
}
return nil
}
// ==================================================================
func walkAst(node ast.Node) ([]ast.Decl, error) {
v := &visitor{}
ast.Walk(v, node)
log.Printf("funcs len: %d", len(v.funcs))
var decls []ast.Decl
for _, v := range v.funcs {
decls = append(decls, v)
}
return decls, nil
}
// getImportPkg 通过 go/build 包获取其在 gomod 中的路径,不用手动填写:
func getImportPkg(pkg string) (string, error) {
p, err := build.Import(pkg, "", build.FindOnly)
if err != nil {
return "", err
}
return p.Dir, err
}
// parseDir 解析整个目录,找到对应包名的的 *ast.Package
func parseDir(dir, pkgName string) (*ast.Package, error) {
// 返回的是一个 package name -> package 的 Map
pkgMap, err := parser.ParseDir(
token.NewFileSet(),
dir,
func(info os.FileInfo) bool {
// skip go-test
return !strings.Contains(info.Name(), "_test.go")
},
parser.Mode(0), // no comment
)
if err != nil {
return nil, errx.WithStack(err)
}
pkg, ok := pkgMap[pkgName]
if !ok {
err := errors.New("not found")
return nil, errx.WithStack(err)
}
return pkg, nil
}
// 定义一个 visitor 结构去实现 ast.Visitor 接口
type visitor struct {
funcs []*ast.FuncDecl
}
// Visit 遍历 ast,找到 SugaredLogger 的所有 Exported 方法:
func (v *visitor) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) {
// 只需处理 FuncDecl
case *ast.FuncDecl:
if n.Recv == nil ||
!n.Name.IsExported() || // 判断方法名称是否为大写(公开
len(n.Recv.List) != 1 { // 取得当前接收者下有多少个方法
return nil
}
t, ok := n.Recv.List[0].Type.(*ast.StarExpr)
if !ok {
return nil
}
if t.X.(*ast.Ident).String() != "SugaredLogger" {
return nil
}
log.Printf("func name: %s", n.Name.String())
v.funcs = append(v.funcs, rewriteFunc(n))
}
return v
}
// rewriteFunc 修改函数的属性
func rewriteFunc(fn *ast.FuncDecl) *ast.FuncDecl {
fn.Recv = nil // 将方法接收者置空,变为函数。
fnName := fn.Name.String()
var args []string
// fn.Type 表示当前当前函数的属性(函数位置、函数的参数列表、函数的返回值)
for _, field := range fn.Type.Params.List {
// 因为 field 可以代表 struct 类型、interface 里面的 method 列表、或者一个参数签名,因此这里的 Names 是一个 List
for _, id := range field.Names {
// 取得参数名
idStr := id.String()
_, ok := field.Type.(*ast.Ellipsis) // 判断当前 field 是否为可变参数
if ok {
// Ellipsis args
idStr += "..."
}
args = append(args, idStr)
}
}
// 函数 body 改为调用 zap.S() 方法。
exprStr := fmt.Sprintf(`zap02.S().%s(%s)`, fnName, strings.Join(args, ","))
expr, err := parser.ParseExpr(exprStr) // 生成这个表达式(方法)的 AST 树
if err != nil {
panic(err)
}
var body []ast.Stmt // 所有 AST Node 都实现了 Stmt
if fn.Type.Results != nil {
body = []ast.Stmt{
// 如果有返回值则需要 return 语句。
&ast.ReturnStmt{
// Return:
Results: []ast.Expr{expr}, // 如果有返回值则需要 return 语句。
},
}
} else {
body = []ast.Stmt{
&ast.ExprStmt{
X: expr,
},
}
}
fn.Body.List = body
return fn
}
// ast 转化为 go 代码
func astToGo(dst *bytes.Buffer, node interface{}) error {
addNewline := func() {
err := dst.WriteByte('\n') // add newline
if err != nil {
log.Panicln(err)
}
}
addNewline()
// 单个 func 的 ast 转化为 go 代码,使用 go/format 包:
err := format.Node(dst, token.NewFileSet(), node)
if err != nil {
return err
}
addNewline()
return nil
}
// writeGoFile 拼装成完整 go file
func writeGoFile(wr io.Writer, funcs []ast.Decl) error {
// 输出Go代码
header := `// Code generated by log-gen. DO NOT EDIT.
package log
import zap02 "go.uber.org/zap"
`
buffer := bytes.NewBufferString(header)
for _, fn := range funcs {
err := astToGo(buffer, fn)
if err != nil {
return errx.WithStack(err)
}
}
_, err := wr.Write(buffer.Bytes())
return err
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment