Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Go AST "if err != nil" rewriter
package main
import (
"errors"
"testing"
)
// ThisMightReturnError always returns a non-nil error whose message is
// "I'm an error".
func ThisMightReturnError() error {
return errors.New("I'm an error")
}
// TestMain calls ThisMightReturnError and aborts the test with t.Fatalf when
// the returned error is non-nil.
func TestMain(t *testing.T) {
err := ThisMightReturnError()
// Stop the test immediately and report the error text.
if err != nil {
t.Fatalf("my error: %s", err)
}
}
$ go run rewrite.go
(*ast.File)(0xc0000ec000)({
Doc: (*ast.CommentGroup)(<nil>),
Package: (token.Pos) 1,
Name: (*ast.Ident)(0xc0000ca0c0)(main),
Decls: ([]ast.Decl) (len=3 cap=4) {
(*ast.GenDecl)(0xc0000be080)({
Doc: (*ast.CommentGroup)(<nil>),
TokPos: (token.Pos) 15,
Tok: (token.Token) import,
Lparen: (token.Pos) 22,
Specs: ([]ast.Spec) (len=2 cap=2) {
(*ast.ImportSpec)(0xc0000a0450)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca0e0)({
ValuePos: (token.Pos) 25,
Kind: (token.Token) STRING,
Value: (string) (len=8) "\"errors\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
}),
(*ast.ImportSpec)(0xc0000a0480)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca100)({
ValuePos: (token.Pos) 35,
Kind: (token.Token) STRING,
Value: (string) (len=9) "\"testing\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
})
},
Rparen: (token.Pos) 45
}),
(*ast.FuncDecl)(0xc0000a05a0)({
Doc: (*ast.CommentGroup)(<nil>),
Recv: (*ast.FieldList)(<nil>),
Name: (*ast.Ident)(0xc0000ca140)(ThisMightReturnError),
Type: (*ast.FuncType)(0xc0000ae078)({
Func: (token.Pos) 48,
Params: (*ast.FieldList)(0xc0000a04e0)({
Opening: (token.Pos) 73,
List: ([]*ast.Field) <nil>,
Closing: (token.Pos) 74
}),
Results: (*ast.FieldList)(0xc0000a0510)({
Opening: (token.Pos) 0,
List: ([]*ast.Field) (len=1 cap=1) {
(*ast.Field)(0xc0000be0c0)({
Doc: (*ast.CommentGroup)(<nil>),
Names: ([]*ast.Ident) <nil>,
Type: (*ast.Ident)(0xc0000ca160)(error),
Tag: (*ast.BasicLit)(<nil>),
Comment: (*ast.CommentGroup)(<nil>)
})
},
Closing: (token.Pos) 0
})
}),
Body: (*ast.BlockStmt)(0xc0000a0570)({
Lbrace: (token.Pos) 82,
List: ([]ast.Stmt) (len=1 cap=1) {
(*ast.ReturnStmt)(0xc0000ca1e0)({
Return: (token.Pos) 85,
Results: ([]ast.Expr) (len=1 cap=1) {
(*ast.CallExpr)(0xc0000be100)({
Fun: (*ast.SelectorExpr)(0xc0000ae060)({
X: (*ast.Ident)(0xc0000ca180)(errors),
Sel: (*ast.Ident)(0xc0000ca1a0)(New)
}),
Lparen: (token.Pos) 102,
Args: ([]ast.Expr) (len=1 cap=1) {
(*ast.BasicLit)(0xc0000ca1c0)({
ValuePos: (token.Pos) 103,
Kind: (token.Token) STRING,
Value: (string) (len=14) "\"I'm an error\""
})
},
Ellipsis: (token.Pos) 0,
Rparen: (token.Pos) 117
})
}
})
},
Rbrace: (token.Pos) 119
})
}),
(*ast.FuncDecl)(0xc0000a0750)({
Doc: (*ast.CommentGroup)(<nil>),
Recv: (*ast.FieldList)(<nil>),
Name: (*ast.Ident)(0xc0000ca220)(TestMain),
Type: (*ast.FuncType)(0xc0000ae0d8)({
Func: (token.Pos) 122,
Params: (*ast.FieldList)(0xc0000a0600)({
Opening: (token.Pos) 135,
List: ([]*ast.Field) (len=1 cap=1) {
(*ast.Field)(0xc0000be140)({
Doc: (*ast.CommentGroup)(<nil>),
Names: ([]*ast.Ident) (len=1 cap=1) {
(*ast.Ident)(0xc0000ca240)(t)
},
Type: (*ast.StarExpr)(0xc0000ae0a8)({
Star: (token.Pos) 138,
X: (*ast.SelectorExpr)(0xc0000ae090)({
X: (*ast.Ident)(0xc0000ca260)(testing),
Sel: (*ast.Ident)(0xc0000ca2a0)(T)
})
}),
Tag: (*ast.BasicLit)(<nil>),
Comment: (*ast.CommentGroup)(<nil>)
})
},
Closing: (token.Pos) 148
}),
Results: (*ast.FieldList)(<nil>)
}),
Body: (*ast.BlockStmt)(0xc0000a0720)({
Lbrace: (token.Pos) 150,
List: ([]ast.Stmt) (len=2 cap=2) {
(*ast.AssignStmt)(0xc0000be1c0)({
Lhs: ([]ast.Expr) (len=1 cap=1) {
(*ast.Ident)(0xc0000ca2c0)(err)
},
TokPos: (token.Pos) 157,
Tok: (token.Token) :=,
Rhs: ([]ast.Expr) (len=1 cap=1) {
(*ast.CallExpr)(0xc0000be180)({
Fun: (*ast.Ident)(0xc0000ca2e0)(ThisMightReturnError),
Lparen: (token.Pos) 180,
Args: ([]ast.Expr) <nil>,
Ellipsis: (token.Pos) 0,
Rparen: (token.Pos) 181
})
}
}),
(*ast.IfStmt)(0xc0000be240)({
If: (token.Pos) 184,
Init: (ast.Stmt) <nil>,
Cond: (*ast.BinaryExpr)(0xc0000a0690)({
X: (*ast.Ident)(0xc0000ca300)(err),
OpPos: (token.Pos) 191,
Op: (token.Token) !=,
Y: (*ast.Ident)(0xc0000ca320)(nil)
}),
Body: (*ast.BlockStmt)(0xc0000a06f0)({
Lbrace: (token.Pos) 198,
List: ([]ast.Stmt) (len=1 cap=1) {
(*ast.ExprStmt)(0xc00009e470)({
X: (*ast.CallExpr)(0xc0000be200)({
Fun: (*ast.SelectorExpr)(0xc0000ae0c0)({
X: (*ast.Ident)(0xc0000ca340)(t),
Sel: (*ast.Ident)(0xc0000ca360)(Fatalf)
}),
Lparen: (token.Pos) 210,
Args: ([]ast.Expr) (len=2 cap=2) {
(*ast.BasicLit)(0xc0000ca380)({
ValuePos: (token.Pos) 211,
Kind: (token.Token) STRING,
Value: (string) (len=14) "\"my error: %s\""
}),
(*ast.Ident)(0xc0000ca3a0)(err)
},
Ellipsis: (token.Pos) 0,
Rparen: (token.Pos) 230
})
})
},
Rbrace: (token.Pos) 233
}),
Else: (ast.Stmt) <nil>
})
},
Rbrace: (token.Pos) 235
})
})
},
Scope: (*ast.Scope)(0xc00009e310)(scope 0xc00009e310 {
func TestMain
func ThisMightReturnError
}
),
Imports: ([]*ast.ImportSpec) (len=2 cap=2) {
(*ast.ImportSpec)(0xc0000a0450)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca0e0)({
ValuePos: (token.Pos) 25,
Kind: (token.Token) STRING,
Value: (string) (len=8) "\"errors\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
}),
(*ast.ImportSpec)(0xc0000a0480)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca100)({
ValuePos: (token.Pos) 35,
Kind: (token.Token) STRING,
Value: (string) (len=9) "\"testing\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
})
},
Unresolved: ([]*ast.Ident) (len=4 cap=4) {
(*ast.Ident)(0xc0000ca160)(error),
(*ast.Ident)(0xc0000ca180)(errors),
(*ast.Ident)(0xc0000ca260)(testing),
(*ast.Ident)(0xc0000ca320)(nil)
},
Comments: ([]*ast.CommentGroup) <nil>
})
package main
import (
"errors"
"testing"
)
func ThisMightReturnError() error {
return errors.New("I'm an error")
}
func TestMain(t *testing.T) {
err := ThisMightReturnError()
require.NotNil(t, "my error: %s", err)
}
package main
import (
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
"github.com/davecgh/go-spew/spew"
"golang.org/x/tools/go/ast/astutil"
)
// isIfErrBlock reports whether n is an if statement of exactly this shape:
//
//	if err != nil { ... }
//
// which corresponds to the following AST:
//
//	(*ast.IfStmt)({
//	  Init: (ast.Stmt) <nil>,
//	  Cond: (*ast.BinaryExpr)({
//	    X:  (*ast.Ident)(err),
//	    Op: (token.Token) !=,
//	    Y:  (*ast.Ident)(nil)
//	  }),
//
// Both operands must be plain identifiers: the left one is named "err" and
// resolves to a variable, the right one is the unresolved (predeclared)
// identifier "nil". Any other node, operator, or operand kind yields false.
func isIfErrBlock(n ast.Node) bool {
	ifStmt, ok := n.(*ast.IfStmt)
	if !ok {
		return false
	}

	// The pattern has no init statement. Accepting `if err := f(); err != nil`
	// would let a caller that replaces the whole statement drop the init.
	if ifStmt.Init != nil {
		return false
	}

	// Condition must be a binary `!=` expression.
	cond, ok := ifStmt.Cond.(*ast.BinaryExpr)
	if !ok || cond.Op != token.NEQ {
		return false
	}

	// Left operand: an identifier `err` that the parser resolved to a variable.
	lhs, ok := cond.X.(*ast.Ident)
	if !ok || lhs.Name != "err" || lhs.Obj == nil || lhs.Obj.Kind != ast.Var {
		return false
	}

	// Right operand: the identifier `nil`, which must be unresolved so that a
	// locally declared object that happens to be called nil does not match.
	rhs, ok := cond.Y.(*ast.Ident)
	if !ok || rhs.Name != "nil" || rhs.Obj != nil {
		return false
	}

	return true
}
// isErrBody reports whether n is an if statement whose body consists of a
// single call to a method or package function named Fatalf (for example
// t.Fatalf(...)) and which has no else branch. On a match it also returns the
// argument expressions of that call; otherwise it returns (false, nil).
//
// Only the selector name is inspected: the receiver/package in front of
// ".Fatalf" is not checked.
func isErrBody(n ast.Node) (bool, []ast.Expr) {
	ifStmt, ok := n.(*ast.IfStmt)
	if !ok {
		return false, nil
	}

	// An else branch carries code that a single replacement call cannot
	// express, so such statements are not considered a match.
	if ifStmt.Else != nil {
		return false, nil
	}

	// The body must hold exactly one statement. The nil guard protects
	// against hand-built nodes without a body.
	if ifStmt.Body == nil || len(ifStmt.Body.List) != 1 {
		return false, nil
	}

	// That statement must be an expression statement wrapping a call.
	exprStmt, ok := ifStmt.Body.List[0].(*ast.ExprStmt)
	if !ok {
		return false, nil
	}
	call, ok := exprStmt.X.(*ast.CallExpr)
	if !ok {
		return false, nil
	}

	// The callee must be a selector expression ending in ".Fatalf".
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != "Fatalf" {
		return false, nil
	}

	return true, call.Args
}
// main parses main_test.go from the current directory, dumps its AST with
// spew, replaces every statement of the form
//
//	if err != nil { x.Fatalf(args...) }
//
// with the call require.NotNil(t, args...), and prints the rewritten file to
// standard output. Parse and print failures terminate the program via
// log.Fatal.
func main() {
	// src is nil, so the parser reads the named file from disk.
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "main_test.go", nil, parser.AllErrors)
	if err != nil {
		log.Fatal(err)
	}

	// Dump the parsed tree for inspection.
	spew.Dump(f)

	// Walk the AST and rewrite matching if statements in place.
	astutil.Apply(f, func(cr *astutil.Cursor) bool {
		node := cr.Node()

		// Only `if err != nil` statements are candidates.
		if !isIfErrBlock(node) {
			return true
		}

		// The body must be a single Fatalf call; reuse its arguments.
		matched, fatalfArgs := isErrBody(node)
		if !matched {
			return true
		}

		callArgs := make([]ast.Expr, 0, len(fatalfArgs)+1)
		callArgs = append(callArgs, ast.NewIdent("t"))
		callArgs = append(callArgs, fatalfArgs...)

		// NOTE(review): the emitted call passes the Fatalf format string as
		// the object under test, so it does not assert on err the way the
		// original `if err != nil` did; require.NoError(t, err, msgAndArgs...)
		// looks like the intended target. No import for the require package
		// is added to the file either. Both are left as-is here to keep the
		// tool's output unchanged — confirm before relying on the result.
		cr.Replace(&ast.ExprStmt{
			X: &ast.CallExpr{
				// A qualified name is a selector expression, not a single
				// identifier containing a dot; it prints the same.
				Fun: &ast.SelectorExpr{
					X:   ast.NewIdent("require"),
					Sel: ast.NewIdent("NotNil"),
				},
				Args: callArgs,
			},
		})

		// The replaced node has no children worth visiting.
		return false
	}, nil)

	// Print the result, reporting write/format failures instead of dropping them.
	if err := printer.Fprint(os.Stdout, fset, f); err != nil {
		log.Fatal(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment