Skip to content

Instantly share code, notes, and snippets.

@djboris9
Created March 26, 2021 14:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save djboris9/f58c7bf19472dab829a858fd426685e9 to your computer and use it in GitHub Desktop.
Save djboris9/f58c7bf19472dab829a858fd426685e9 to your computer and use it in GitHub Desktop.
Go AST "if err != nil" rewriter
package main
import (
"errors"
"testing"
)
// ThisMightReturnError is a fixture for the rewriter: it never succeeds and
// always hands back a newly created error with the message "I'm an error".
func ThisMightReturnError() error {
	failure := errors.New("I'm an error")
	return failure
}
// TestMain is the input fixture for rewrite.go. Its body contains the
// `err := ...; if err != nil { t.Fatalf(...) }` pattern that the rewriter
// detects and replaces, so the statement shape here is deliberate.
//
// Because ThisMightReturnError always returns a non-nil error, this test
// always fails via t.Fatalf.
func TestMain(t *testing.T) {
err := ThisMightReturnError()
// The pattern matched by isIfErrBlock (condition) and isErrBody (body).
if err != nil {
t.Fatalf("my error: %s", err)
}
}
$ go run rewrite.go
(*ast.File)(0xc0000ec000)({
Doc: (*ast.CommentGroup)(<nil>),
Package: (token.Pos) 1,
Name: (*ast.Ident)(0xc0000ca0c0)(main),
Decls: ([]ast.Decl) (len=3 cap=4) {
(*ast.GenDecl)(0xc0000be080)({
Doc: (*ast.CommentGroup)(<nil>),
TokPos: (token.Pos) 15,
Tok: (token.Token) import,
Lparen: (token.Pos) 22,
Specs: ([]ast.Spec) (len=2 cap=2) {
(*ast.ImportSpec)(0xc0000a0450)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca0e0)({
ValuePos: (token.Pos) 25,
Kind: (token.Token) STRING,
Value: (string) (len=8) "\"errors\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
}),
(*ast.ImportSpec)(0xc0000a0480)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca100)({
ValuePos: (token.Pos) 35,
Kind: (token.Token) STRING,
Value: (string) (len=9) "\"testing\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
})
},
Rparen: (token.Pos) 45
}),
(*ast.FuncDecl)(0xc0000a05a0)({
Doc: (*ast.CommentGroup)(<nil>),
Recv: (*ast.FieldList)(<nil>),
Name: (*ast.Ident)(0xc0000ca140)(ThisMightReturnError),
Type: (*ast.FuncType)(0xc0000ae078)({
Func: (token.Pos) 48,
Params: (*ast.FieldList)(0xc0000a04e0)({
Opening: (token.Pos) 73,
List: ([]*ast.Field) <nil>,
Closing: (token.Pos) 74
}),
Results: (*ast.FieldList)(0xc0000a0510)({
Opening: (token.Pos) 0,
List: ([]*ast.Field) (len=1 cap=1) {
(*ast.Field)(0xc0000be0c0)({
Doc: (*ast.CommentGroup)(<nil>),
Names: ([]*ast.Ident) <nil>,
Type: (*ast.Ident)(0xc0000ca160)(error),
Tag: (*ast.BasicLit)(<nil>),
Comment: (*ast.CommentGroup)(<nil>)
})
},
Closing: (token.Pos) 0
})
}),
Body: (*ast.BlockStmt)(0xc0000a0570)({
Lbrace: (token.Pos) 82,
List: ([]ast.Stmt) (len=1 cap=1) {
(*ast.ReturnStmt)(0xc0000ca1e0)({
Return: (token.Pos) 85,
Results: ([]ast.Expr) (len=1 cap=1) {
(*ast.CallExpr)(0xc0000be100)({
Fun: (*ast.SelectorExpr)(0xc0000ae060)({
X: (*ast.Ident)(0xc0000ca180)(errors),
Sel: (*ast.Ident)(0xc0000ca1a0)(New)
}),
Lparen: (token.Pos) 102,
Args: ([]ast.Expr) (len=1 cap=1) {
(*ast.BasicLit)(0xc0000ca1c0)({
ValuePos: (token.Pos) 103,
Kind: (token.Token) STRING,
Value: (string) (len=14) "\"I'm an error\""
})
},
Ellipsis: (token.Pos) 0,
Rparen: (token.Pos) 117
})
}
})
},
Rbrace: (token.Pos) 119
})
}),
(*ast.FuncDecl)(0xc0000a0750)({
Doc: (*ast.CommentGroup)(<nil>),
Recv: (*ast.FieldList)(<nil>),
Name: (*ast.Ident)(0xc0000ca220)(TestMain),
Type: (*ast.FuncType)(0xc0000ae0d8)({
Func: (token.Pos) 122,
Params: (*ast.FieldList)(0xc0000a0600)({
Opening: (token.Pos) 135,
List: ([]*ast.Field) (len=1 cap=1) {
(*ast.Field)(0xc0000be140)({
Doc: (*ast.CommentGroup)(<nil>),
Names: ([]*ast.Ident) (len=1 cap=1) {
(*ast.Ident)(0xc0000ca240)(t)
},
Type: (*ast.StarExpr)(0xc0000ae0a8)({
Star: (token.Pos) 138,
X: (*ast.SelectorExpr)(0xc0000ae090)({
X: (*ast.Ident)(0xc0000ca260)(testing),
Sel: (*ast.Ident)(0xc0000ca2a0)(T)
})
}),
Tag: (*ast.BasicLit)(<nil>),
Comment: (*ast.CommentGroup)(<nil>)
})
},
Closing: (token.Pos) 148
}),
Results: (*ast.FieldList)(<nil>)
}),
Body: (*ast.BlockStmt)(0xc0000a0720)({
Lbrace: (token.Pos) 150,
List: ([]ast.Stmt) (len=2 cap=2) {
(*ast.AssignStmt)(0xc0000be1c0)({
Lhs: ([]ast.Expr) (len=1 cap=1) {
(*ast.Ident)(0xc0000ca2c0)(err)
},
TokPos: (token.Pos) 157,
Tok: (token.Token) :=,
Rhs: ([]ast.Expr) (len=1 cap=1) {
(*ast.CallExpr)(0xc0000be180)({
Fun: (*ast.Ident)(0xc0000ca2e0)(ThisMightReturnError),
Lparen: (token.Pos) 180,
Args: ([]ast.Expr) <nil>,
Ellipsis: (token.Pos) 0,
Rparen: (token.Pos) 181
})
}
}),
(*ast.IfStmt)(0xc0000be240)({
If: (token.Pos) 184,
Init: (ast.Stmt) <nil>,
Cond: (*ast.BinaryExpr)(0xc0000a0690)({
X: (*ast.Ident)(0xc0000ca300)(err),
OpPos: (token.Pos) 191,
Op: (token.Token) !=,
Y: (*ast.Ident)(0xc0000ca320)(nil)
}),
Body: (*ast.BlockStmt)(0xc0000a06f0)({
Lbrace: (token.Pos) 198,
List: ([]ast.Stmt) (len=1 cap=1) {
(*ast.ExprStmt)(0xc00009e470)({
X: (*ast.CallExpr)(0xc0000be200)({
Fun: (*ast.SelectorExpr)(0xc0000ae0c0)({
X: (*ast.Ident)(0xc0000ca340)(t),
Sel: (*ast.Ident)(0xc0000ca360)(Fatalf)
}),
Lparen: (token.Pos) 210,
Args: ([]ast.Expr) (len=2 cap=2) {
(*ast.BasicLit)(0xc0000ca380)({
ValuePos: (token.Pos) 211,
Kind: (token.Token) STRING,
Value: (string) (len=14) "\"my error: %s\""
}),
(*ast.Ident)(0xc0000ca3a0)(err)
},
Ellipsis: (token.Pos) 0,
Rparen: (token.Pos) 230
})
})
},
Rbrace: (token.Pos) 233
}),
Else: (ast.Stmt) <nil>
})
},
Rbrace: (token.Pos) 235
})
})
},
Scope: (*ast.Scope)(0xc00009e310)(scope 0xc00009e310 {
func TestMain
func ThisMightReturnError
}
),
Imports: ([]*ast.ImportSpec) (len=2 cap=2) {
(*ast.ImportSpec)(0xc0000a0450)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca0e0)({
ValuePos: (token.Pos) 25,
Kind: (token.Token) STRING,
Value: (string) (len=8) "\"errors\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
}),
(*ast.ImportSpec)(0xc0000a0480)({
Doc: (*ast.CommentGroup)(<nil>),
Name: (*ast.Ident)(<nil>),
Path: (*ast.BasicLit)(0xc0000ca100)({
ValuePos: (token.Pos) 35,
Kind: (token.Token) STRING,
Value: (string) (len=9) "\"testing\""
}),
Comment: (*ast.CommentGroup)(<nil>),
EndPos: (token.Pos) 0
})
},
Unresolved: ([]*ast.Ident) (len=4 cap=4) {
(*ast.Ident)(0xc0000ca160)(error),
(*ast.Ident)(0xc0000ca180)(errors),
(*ast.Ident)(0xc0000ca260)(testing),
(*ast.Ident)(0xc0000ca320)(nil)
},
Comments: ([]*ast.CommentGroup) <nil>
})
package main
import (
"errors"
"testing"
)
func ThisMightReturnError() error {
return errors.New("I'm an error")
}
func TestMain(t *testing.T) {
err := ThisMightReturnError()
require.NotNil(t, "my error: %s", err)
}
package main
import (
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
"github.com/davecgh/go-spew/spew"
"golang.org/x/tools/go/ast/astutil"
)
// isIfErrBlock reports whether n is an if statement whose condition is exactly
// `err != nil`:
//
//   - n is an *ast.IfStmt whose Cond is an *ast.BinaryExpr with Op token.NEQ,
//   - the left operand is an identifier named "err" that the parser resolved
//     to a variable (ast.Ident.Obj.Kind == ast.Var), and
//   - the right operand is the identifier "nil" with no resolved object,
//     i.e. the predeclared nil rather than a shadowing local declaration.
//
// Operands that are not plain identifiers (x.err, f(), io.EOF, ...) never
// match. The Init and Else parts of the if statement are not inspected.
//
// Example AST fragment that matches:
//
//	(*ast.IfStmt)({
//	    Cond: (*ast.BinaryExpr)({
//	        X:  (*ast.Ident)(err),
//	        Op: (token.Token) !=,
//	        Y:  (*ast.Ident)(nil)
//	    }),
//	    ...
//	})
func isIfErrBlock(n ast.Node) bool {
	ifStmt, ok := n.(*ast.IfStmt)
	if !ok {
		return false
	}
	binExpr, ok := ifStmt.Cond.(*ast.BinaryExpr)
	if !ok || binExpr.Op != token.NEQ {
		return false
	}
	// Left operand must be a variable named err. Requiring a plain identifier
	// rejects comparisons such as `x.err != nil` or `f() != nil`.
	lhs, ok := binExpr.X.(*ast.Ident)
	if !ok || lhs.Name != "err" || lhs.Obj == nil || lhs.Obj.Kind != ast.Var {
		return false
	}
	// Right operand must be the predeclared nil. A non-identifier such as
	// io.EOF is a different comparison; a resolved Obj means nil is shadowed.
	rhs, ok := binExpr.Y.(*ast.Ident)
	return ok && rhs.Name == "nil" && rhs.Obj == nil
}
// isErrBody reports whether n is an if statement whose body is exactly one
// expression statement calling a method named Fatalf, for example
// `{ t.Fatalf("msg: %s", err) }`. On a match it returns true together with
// the call's argument list (the original AST nodes, not copies). Otherwise it
// returns false and nil.
//
// Only the selector name is checked; the receiver expression in front of
// `.Fatalf` may be anything.
func isErrBody(n ast.Node) (bool, []ast.Expr) {
	ifStmt, isIf := n.(*ast.IfStmt)
	if !isIf || len(ifStmt.Body.List) != 1 {
		return false, nil
	}

	exprStmt, isExpr := ifStmt.Body.List[0].(*ast.ExprStmt)
	if !isExpr {
		return false, nil
	}

	callExpr, isCall := exprStmt.X.(*ast.CallExpr)
	if !isCall {
		return false, nil
	}

	if sel, isSel := callExpr.Fun.(*ast.SelectorExpr); isSel && sel.Sel.Name == "Fatalf" {
		return true, callExpr.Args
	}
	return false, nil
}
// main parses a Go source file, dumps its AST with spew, rewrites every
// statement of the form
//
//	if err != nil {
//		t.Fatalf(format, args...)
//	}
//
// into
//
//	require.NoError(t, err, format, args...)
//
// and prints the resulting source to stdout. The require import is added
// when at least one statement was rewritten.
//
// The input path is the first command-line argument, defaulting to
// "main_test.go" when no argument is given.
func main() {
	path := "main_test.go"
	if len(os.Args) > 1 {
		path = os.Args[1]
	}

	// Parse file. The parser's default object resolution must stay enabled:
	// isIfErrBlock inspects ast.Ident.Obj.
	fs := token.NewFileSet()
	f, err := parser.ParseFile(fs, path, nil, parser.AllErrors)
	if err != nil {
		log.Fatal(err)
	}

	// Dump data
	spew.Dump(f)

	// Walk through AST and replace every matching if statement.
	rewritten := false
	astutil.Apply(f, func(cr *astutil.Cursor) bool {
		ifStmt, ok := cr.Node().(*ast.IfStmt)
		if !ok || !isIfErrBlock(ifStmt) {
			return true
		}
		// Check if body contains only one Fatalf call and get its args.
		sigRet, args := isErrBody(ifStmt)
		if !sigRet {
			return true
		}
		call, ok := buildNoErrorCall(ifStmt, args)
		if !ok {
			return true
		}
		cr.Replace(&ast.ExprStmt{X: call})
		rewritten = true
		return false
	}, nil)

	// The replacement refers to the require package; import it so the
	// printed file compiles.
	if rewritten {
		astutil.AddImport(fs, f, "github.com/stretchr/testify/require")
	}

	// Print result
	if err := printer.Fprint(os.Stdout, fs, f); err != nil {
		log.Fatal(err)
	}
}

// buildNoErrorCall returns the expression
//
//	require.NoError(<recv>, <errExpr>, args...)
//
// that replaces ifStmt. Here <recv> is the receiver of the single Fatalf call
// in the if body, <errExpr> is the left operand of the `!= nil` condition, and
// args are the original Fatalf arguments (a format string plus values).
//
// It reports false when ifStmt has an Init or Else clause, because replacing
// the whole statement would silently drop that code. It also reports false
// when the statement lacks the shape that isIfErrBlock and isErrBody accept.
func buildNoErrorCall(ifStmt *ast.IfStmt, args []ast.Expr) (*ast.CallExpr, bool) {
	if ifStmt.Init != nil || ifStmt.Else != nil {
		return nil, false
	}
	cond, ok := ifStmt.Cond.(*ast.BinaryExpr)
	if !ok || len(ifStmt.Body.List) != 1 {
		return nil, false
	}
	stmt, ok := ifStmt.Body.List[0].(*ast.ExprStmt)
	if !ok {
		return nil, false
	}
	fatalf, ok := stmt.X.(*ast.CallExpr)
	if !ok {
		return nil, false
	}
	sel, ok := fatalf.Fun.(*ast.SelectorExpr)
	if !ok {
		return nil, false
	}

	errExpr := cond.X
	if id, isIdent := errExpr.(*ast.Ident); isIdent {
		// Copy the identifier without its source position. The original node
		// lies on the if-condition line, one line above the Fatalf arguments.
		// Mixing those positions makes go/printer split the argument list
		// across lines.
		errExpr = ast.NewIdent(id.Name)
	}

	callArgs := make([]ast.Expr, 0, 2+len(args))
	callArgs = append(callArgs, sel.X, errExpr)
	callArgs = append(callArgs, args...)
	return &ast.CallExpr{
		Fun: &ast.SelectorExpr{
			X:   ast.NewIdent("require"),
			Sel: ast.NewIdent("NoError"),
		},
		Args: callArgs,
	}, true
}
@venjiang
Copy link

👍

@aditya-deepak1
Copy link

👍

@aditya-deepak1
Copy link

aditya-deepak1 commented Oct 8, 2022

@djboris9 Is there a way to completely restructure the files of existing Go code for better organization, similar to https://github.com/golang-standards/project-layout? The idea is to transform old, non-standard code into the standard layout. We want to do this for all the repositories we own. Any leads would be really helpful.

@djboris9
Copy link
Author

djboris9 commented Oct 9, 2022

@aditya-deepak1 This is indeed possible with this method and works similarly to this example, but it requires much more logic: extracting the parts, rewriting them into the structure you want, and then serializing them again.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment