Skip to content

Instantly share code, notes, and snippets.

@shezadkhan137
Created June 7, 2020 20:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shezadkhan137/49ac836992b26fe422bc1baeb029ba86 to your computer and use it in GitHub Desktop.
Save shezadkhan137/49ac836992b26fe422bc1baeb029ba86 to your computer and use it in GitHub Desktop.
Go Module-aware AST rewriting
package main
import (
"bytes"
"go/ast"
"go/printer"
"go/types"
"io/ioutil"
"os"
"github.com/fatih/astrewrite"
"github.com/pkg/errors"
"golang.org/x/tools/go/packages"
)
// main rewrites the log.Printf calls of the package pattern passed as the
// first command-line argument. Extra arguments are ignored.
//
// A missing argument prints a usage line and exits with status 2 instead of
// panicking on os.Args[1]; a failed run prints the error and exits with 1.
func main() {
	if len(os.Args) < 2 {
		os.Stderr.WriteString("usage: " + os.Args[0] + " <package>\n")
		os.Exit(2)
	}
	if err := load(os.Args[1]); err != nil {
		os.Stderr.WriteString(err.Error() + "\n")
		os.Exit(1)
	}
}
// load type-checks the packages matching packageName and rewrites the
// log.Printf calls in each of them via Rewriter.Rewrite.
//
// packages.Load does not report per-package problems (list, parse or type
// errors) through its returned error; they are recorded in Package.Errors.
// Every package is checked for such errors before any file is written. With
// incomplete type information the rewriter cannot reliably tell which calls
// are log.Printf, and a partial rewrite of the tree is avoided.
func load(packageName string) error {
	config := &packages.Config{
		Mode: packages.LoadSyntax,
	}
	pkgs, err := packages.Load(config, packageName)
	if err != nil {
		return errors.Wrapf(err, "Error loading package %s", packageName)
	}
	if len(pkgs) == 0 {
		return errors.Errorf("no packages matched %s", packageName)
	}
	for _, pkg := range pkgs {
		if len(pkg.Errors) > 0 {
			return errors.Errorf("package %s has %d error(s), first: %v",
				pkg.PkgPath, len(pkg.Errors), pkg.Errors[0])
		}
	}
	for _, pkg := range pkgs {
		rewriter := &Rewriter{pkg: pkg}
		if err := rewriter.Rewrite(); err != nil {
			return err
		}
	}
	return nil
}
// scope pairs the type checker's scope for a function with the identifier
// of that function's context.Context parameter.
type scope struct {
	// s is the function scope looked up in TypesInfo.Scopes; call sites are
	// tested against it with Contains.
	s *types.Scope

	// i is the identifier naming the context parameter.
	i *ast.Ident
}
// Rewriter rewrites the log.Printf calls of one loaded package.
type Rewriter struct {
	// pkg is the loaded package; its Syntax trees and TypesInfo drive the
	// rewrite.
	pkg *packages.Package

	// contextScopes collects, for every function declaration seen by
	// visitForContextFuncs that has a context.Context parameter, the
	// function's scope and the parameter's identifier.
	contextScopes []*scope

	// fileHasLogReWritten is set to true by visitForLogStatements whenever
	// it rewrites a log call.
	fileHasLogReWritten bool
}
// Rewrite adds a context argument to the log.Printf calls in each source
// file of r.pkg. It writes back only the files in which at least one call
// changed, and returns the first printing or write error.
//
// Syntax trees are matched to file names through the package's FileSet
// rather than by index. packages.Package.Syntax corresponds to
// CompiledGoFiles, which can differ from GoFiles (e.g. for cgo packages, or
// when a file failed to parse). Zipping GoFiles with Syntax could therefore
// write one file's tree over another, or index out of range. Trees whose
// file is not listed in GoFiles (generated files) are skipped, so nothing
// outside the source tree is overwritten.
//
// Output uses the printer modes go/format uses (spaces for alignment, tabs
// for indentation, tab width 8), so rewritten files stay gofmt-style.
func (r *Rewriter) Rewrite() error {
	sources := make(map[string]bool, len(r.pkg.GoFiles))
	for _, name := range r.pkg.GoFiles {
		sources[name] = true
	}
	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}

	for _, file := range r.pkg.Syntax {
		tokFile := r.pkg.Fset.File(file.Pos())
		if tokFile == nil || !sources[tokFile.Name()] {
			continue
		}
		filename := tokFile.Name()

		astrewrite.Walk(file, r.visitForContextFuncs)

		// The flag is per file: reset it so an earlier file's rewrite does
		// not force this one to be written.
		r.fileHasLogReWritten = false
		rewritten := astrewrite.Walk(file, r.visitForLogStatements)
		if !r.fileHasLogReWritten {
			// Nothing changed; leave the file and its formatting untouched.
			continue
		}

		var buf bytes.Buffer
		if err := cfg.Fprint(&buf, r.pkg.Fset, rewritten); err != nil {
			return err
		}
		if err := ioutil.WriteFile(filename, buf.Bytes(), 0644); err != nil {
			return err
		}
	}
	return nil
}
// visitForContextFuncs is the callback passed to astrewrite.Walk in the
// first pass. For every function declaration that has a context.Context
// parameter, it appends the function's scope to r.contextScopes. It never
// modifies the tree and always asks the walk to continue.
func (r *Rewriter) visitForContextFuncs(n ast.Node) (ast.Node, bool) {
	decl, isFunc := n.(*ast.FuncDecl)
	if !isFunc {
		return n, true
	}
	if found := r.checkHasContext(decl); found != nil {
		r.contextScopes = append(r.contextScopes, found)
	}
	return n, true
}
// checkHasContext returns node's function scope paired with the identifier
// of its context.Context parameter. It returns nil when node has no usable
// context parameter, or when TypesInfo holds no scope for the function's
// type.
func (r *Rewriter) checkHasContext(node *ast.FuncDecl) *scope {
	id, ok := hasContextParam(node.Type.Params)
	if !ok {
		return nil
	}
	funcScope := r.pkg.TypesInfo.Scopes[node.Type]
	if funcScope == nil {
		// hasContextInScope calls Contains on the stored scope with no nil
		// check. Treat missing type info as "no context available" instead
		// of storing a nil scope.
		return nil
	}
	return &scope{s: funcScope, i: id}
}
// hasContextParam returns the first named, non-blank parameter in params
// whose type is written as context.Context, and reports whether one was
// found.
//
// The type match is textual (via getParamName), so an aliased import such as
// `ctxpkg.Context` is not recognised. Blank (`_`) names are skipped because
// the identifier is later used as a call argument, and `_` cannot be used as
// a value. A nil params list yields no match.
func hasContextParam(params *ast.FieldList) (*ast.Ident, bool) {
	if params == nil {
		return nil, false
	}
	for _, param := range params.List {
		if getParamName(param.Type) != "context.Context" {
			continue
		}
		for _, name := range param.Names {
			if name.Name != "_" {
				return name, true
			}
		}
	}
	return nil, false
}
// getParamName renders a type expression of the form `pkg.Name` (a selector
// whose left side is a plain identifier) as the string "pkg.Name". Any other
// expression (bare identifiers, pointers, nested selectors, ...) yields "".
func getParamName(n ast.Expr) string {
	if sel, isSel := n.(*ast.SelectorExpr); isSel {
		if pkgIdent, isIdent := sel.X.(*ast.Ident); isIdent {
			return pkgIdent.Name + "." + sel.Sel.Name
		}
	}
	return ""
}
// visitForLogStatements is the callback passed to astrewrite.Walk in the
// second pass. It replaces each call recognised by checkHasLogStatement with
// the result of rewriteLog, and sets r.fileHasLogReWritten. The enclosing
// function's context scope is passed when hasContextInScope finds one;
// otherwise nil is passed. The walk always continues.
func (r *Rewriter) visitForLogStatements(n ast.Node) (ast.Node, bool) {
	call, isCall := n.(*ast.CallExpr)
	if !isCall {
		return n, true
	}
	id, isLog := r.checkHasLogStatement(call)
	if !isLog {
		return n, true
	}
	r.fileHasLogReWritten = true

	var enclosing *scope
	if found, ok := r.hasContextInScope(id); ok {
		enclosing = found
	}
	return r.rewriteLog(call, enclosing), true
}
// checkHasLogStatement reports whether n is a selector call `x.Printf(...)`
// whose selected object belongs to the package with import path "log". If
// so, it returns the `Printf` identifier.
//
// NOTE(review): only the object's package is checked, so method calls such
// as logger.Printf on a *log.Logger presumably match as well — confirm that
// this is intended.
func (r *Rewriter) checkHasLogStatement(n *ast.CallExpr) (*ast.Ident, bool) {
	sel, isSel := n.Fun.(*ast.SelectorExpr)
	if !isSel || sel.Sel == nil {
		return nil, false
	}
	ident := sel.Sel
	if r.pkg.TypesInfo.Types[ident].IsType() || ident.Name != "Printf" {
		return nil, false
	}
	obj, found := r.pkg.TypesInfo.Uses[ident]
	if !found || obj.Pkg().Path() != "log" {
		return nil, false
	}
	return ident, true
}
// hasContextInScope returns the first collected context scope whose extent
// contains id's position, and whether one was found.
func (r *Rewriter) hasContextInScope(id *ast.Ident) (*scope, bool) {
	for i := range r.contextScopes {
		candidate := r.contextScopes[i]
		if candidate.s.Contains(id.Pos()) {
			return candidate, true
		}
	}
	return nil, false
}
// rewriteLog prepends a context argument to the call n, mutating n in place,
// and returns n.
//
// When scope is non-nil, an identifier with the name of the enclosing
// function's context parameter is passed. Otherwise the call receives
// `context.TODO()`.
//
// The argument is a fresh *ast.Ident rather than scope.i itself. Reusing the
// parameter's node would put one node at two places in the tree, and its
// position (the parameter declaration) can make go/printer break the
// argument list across lines.
//
// NOTE(review): nothing visible here adds an import of "context". Files that
// receive context.TODO() without already importing it will not compile until
// the import is added (e.g. by running goimports).
func (r *Rewriter) rewriteLog(n *ast.CallExpr, scope *scope) ast.Node {
	var ctxArg ast.Expr
	if scope != nil {
		ctxArg = ast.NewIdent(scope.i.Name)
	} else {
		ctxArg = &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   ast.NewIdent("context"),
				Sel: ast.NewIdent("TODO"),
			},
		}
	}
	args := make([]ast.Expr, 0, len(n.Args)+1)
	args = append(args, ctxArg)
	n.Args = append(args, n.Args...)
	return n
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment