Skip to content

Instantly share code, notes, and snippets.

@josharian josharian/np2n.go
Created Mar 21, 2016

Embed
What would you like to do?
np2n.go
package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"log"
"os"
"strings"
"github.com/kr/fs"
)
type visitor struct {
fset *token.FileSet
fn []string
changed bool
}
// fnmatch reports whether s is exactly one of the function names in v.fn.
func (v *visitor) fnmatch(s string) bool {
	for i := 0; i < len(v.fn); i++ {
		if v.fn[i] == s {
			return true
		}
	}
	return false
}
// gofmt returns n printed as gofmt-style Go source, using v.fset to
// resolve positions. Any formatting error terminates the program.
func (v *visitor) gofmt(n ast.Node) string {
	out := new(bytes.Buffer)
	err := format.Node(out, v.fset, n)
	if err != nil {
		log.Fatal(err)
	}
	return out.String()
}
// rewriteAddr converts
//
//	fn(&n, ...)
//
// to
//
//	n = fn(n, ...)
//
// s is the statement being replaced. Its position is reused for the new
// left-hand side so that comments attached to s stay in place.
//
// It reports false, and leaves call untouched, when call has no arguments,
// when the first argument is not an address-of expression, or when the
// operand of & is a composite literal. In that last case, &T{...} is legal
// but T{...} = fn(T{...}) is not.
func (v *visitor) rewriteAddr(s ast.Stmt, call *ast.CallExpr) (ast.Stmt, bool) {
	if len(call.Args) == 0 {
		return nil, false
	}
	a0, ok := call.Args[0].(*ast.UnaryExpr)
	if !ok || a0.Op != token.AND {
		return nil, false
	}
	operand := a0.X // strip &
	if _, isLit := operand.(*ast.CompositeLit); isLit {
		// A composite literal cannot be assigned to, so the rewrite
		// would produce uncompilable code.
		return nil, false
	}
	call.Args[0] = operand
	// Cheat a bit: render the operand as source text and wrap it in an
	// identifier, instead of placing the same operand node in the AST twice.
	lhs := ast.NewIdent(v.gofmt(operand))
	lhs.NamePos = s.Pos() // attempt to preserve comment location
	return &ast.AssignStmt{
		Lhs: []ast.Expr{lhs},
		Tok: token.ASSIGN,
		Rhs: []ast.Expr{call},
	}, true
}
// rewriteIndex converts
//
//	fn(n.Addr(i), ...)
//
// to
//
//	n.SetIndex(i, fn(n.Index(i), ...))
//
// s is the statement being replaced. As in rewriteAddr, its position is
// reused for the new call so that comments attached to s stay in place.
//
// It reports false, and leaves call untouched, when call has no arguments
// or when its first argument is not a one-argument method call named Addr.
func (v *visitor) rewriteIndex(s ast.Stmt, call *ast.CallExpr) (ast.Stmt, bool) {
	if len(call.Args) == 0 {
		return nil, false
	}
	inner, ok := call.Args[0].(*ast.CallExpr)
	if !ok || len(inner.Args) != 1 {
		return nil, false
	}
	sel, ok := inner.Fun.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != "Addr" {
		return nil, false
	}
	// Edit in place: n.Addr(i) becomes n.Index(i) inside call.
	sel.Sel.Name = "Index"
	// Cheat a bit, as rewriteAddr does: build "n.SetIndex" as one
	// identifier from the receiver's source text.
	setIndex := ast.NewIdent(v.gofmt(sel.X) + ".SetIndex")
	setIndex.NamePos = s.Pos() // attempt to preserve comment location
	return &ast.ExprStmt{
		X: &ast.CallExpr{
			Fun:  setIndex,
			Args: []ast.Expr{inner.Args[0], call},
		},
	}, true
}
// rewrite returns the replacement for statement s, whose expression is call.
// The address-of form (rewriteAddr) is tried first and the Addr/Index form
// (rewriteIndex) second. The first rewrite that applies wins, and false is
// reported when neither applies.
func (v *visitor) rewrite(s ast.Stmt, call *ast.CallExpr) (ast.Stmt, bool) {
	if repl, done := v.rewriteAddr(s, call); done {
		return repl, true
	}
	if repl, done := v.rewriteIndex(s, call); done {
		return repl, true
	}
	return nil, false
}
// Visit implements ast.Visitor. It handles every statement list that can
// hold a bare call statement:
//
//   - block bodies
//   - switch case bodies
//   - select case bodies
//   - the single statement under a label
//
// In those lists, it replaces each expression statement of the form
// fn(arg0, ...), where fn is a plain identifier listed in v.fn, with the
// statement produced by v.rewrite, and sets v.changed.
//
// Matching calls that have no arguments or cannot be rewritten are reported
// on stdout and left as they are. Visit always returns v, so the walk
// continues into children.
func (v *visitor) Visit(n ast.Node) ast.Visitor {
	var list []ast.Stmt
	var labeled *ast.LabeledStmt
	switch n := n.(type) {
	case *ast.BlockStmt:
		list = n.List
	case *ast.CaseClause:
		list = n.Body
	case *ast.CommClause:
		// select cases hold statement lists just like switch cases.
		list = n.Body
	case *ast.LabeledStmt:
		// A label wraps exactly one statement. Process it as a
		// one-element list and copy any replacement back into the label.
		list = []ast.Stmt{n.Stmt}
		labeled = n
	default:
		return v
	}
	for i, s := range list {
		es, ok := s.(*ast.ExprStmt)
		if !ok {
			continue
		}
		call, ok := es.X.(*ast.CallExpr)
		if !ok {
			continue
		}
		// Only calls through a plain identifier are candidates;
		// selector calls such as x.fn(...) are left alone.
		if id, ok := call.Fun.(*ast.Ident); !ok || !v.fnmatch(id.Name) {
			continue
		}
		if len(call.Args) == 0 {
			fmt.Printf("missing args in %s at %s\n", v.gofmt(call), v.fset.Position(call.Pos()))
			continue
		}
		t, ok := v.rewrite(s, call)
		if !ok {
			fmt.Printf("could not rewrite %s at %s\n", v.gofmt(call), v.fset.Position(call.Pos()))
			continue
		}
		// For block and case bodies, list shares the node's own backing
		// array, so this assignment edits the AST directly.
		list[i] = t
		if labeled != nil {
			labeled.Stmt = t
		}
		v.changed = true
	}
	return v
}
// main rewrites, in place, the .go files under the directory given as the
// first argument. The remaining arguments name the functions whose call
// statements should be rewritten:
//
//	np2n dir func...
//
// Only files in which at least one statement was rewritten are written
// back. Any walk, parse, format, or write error terminates the program.
func main() {
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "usage: %s dir func...\n", os.Args[0])
	}
	flag.Parse()
	// Read positional arguments through flag, not os.Args, so that a "--"
	// terminator is not mistaken for the directory or a function name.
	if flag.NArg() < 2 {
		flag.Usage()
		os.Exit(2)
	}
	walker := fs.Walk(flag.Arg(0))
	v := &visitor{
		fn: flag.Args()[1:],
	}
	for walker.Step() {
		if err := walker.Err(); err != nil {
			log.Fatal(err)
		}
		path := walker.Path()
		if walker.Stat().IsDir() || !strings.HasSuffix(path, ".go") {
			continue
		}
		fset := token.NewFileSet()
		f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
		if err != nil {
			log.Fatal(err)
		}
		v.fset = fset
		v.changed = false
		ast.Walk(v, f)
		if !v.changed {
			continue
		}
		// Format into memory first. Truncating the file before formatting
		// succeeds would destroy it if format.Node failed.
		var buf bytes.Buffer
		if err := format.Node(&buf, fset, f); err != nil {
			log.Fatal(err)
		}
		w, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0644)
		if err != nil {
			log.Fatal(err)
		}
		if _, err := w.Write(buf.Bytes()); err != nil {
			log.Fatal(err)
		}
		// Close each file explicitly: a deferred close would only run when
		// main returns. An explicit close avoids holding one descriptor per
		// rewritten file and surfaces late write errors.
		if err := w.Close(); err != nil {
			log.Fatal(err)
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.