Skip to content

Instantly share code, notes, and snippets.

@mtsmfm
Created November 24, 2020 01:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mtsmfm/ae8b40c4a075f16de4081fa35363ce2e to your computer and use it in GitHub Desktop.
Save mtsmfm/ae8b40c4a075f16de4081fa35363ce2e to your computer and use it in GitHub Desktop.
rewrite.go
package main
import (
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"golang.org/x/tools/go/ast/astutil"
)
// visitFn lets an ordinary callback be used wherever an ast.Visitor is
// expected (for example with ast.Walk).
//
// NOTE(review): nothing in main currently uses this type.
type visitFn func(node ast.Node)

// Visit hands n to the wrapped callback and then returns the same callback,
// so a walk keeps descending into every child with it. The callback must
// tolerate a nil n: ast.Walk calls Visit(nil) after a node's children.
func (f visitFn) Visit(n ast.Node) ast.Visitor {
	var next ast.Visitor = f
	f(n)
	return next
}
// selectTarget describes one db.Select(&Name, <literal>, ...) call found in
// the input.
type selectTarget struct {
	// Name is the identifier whose address is passed as the first argument.
	Name string
	// SQL is the second argument's literal exactly as written in the source,
	// quotes included (ast.BasicLit.Value). It is collected but not yet used
	// by the rewrite.
	SQL string
}

// matchSelectCall reports whether call has the shape
// db.Select(&name, <basic literal>, ...) and, if so, returns the destination
// name and the literal's source text.
func matchSelectCall(call *ast.CallExpr) (selectTarget, bool) {
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != "Select" {
		return selectTarget{}, false
	}
	recv, ok := sel.X.(*ast.Ident)
	if !ok || recv.Name != "db" || len(call.Args) < 2 {
		return selectTarget{}, false
	}
	// Only &name is a destination; other unary operators (-x, !x, <-ch) are
	// also *ast.UnaryExpr and must be rejected explicitly.
	addr, ok := call.Args[0].(*ast.UnaryExpr)
	if !ok || addr.Op != token.AND {
		return selectTarget{}, false
	}
	dest, ok := addr.X.(*ast.Ident)
	if !ok {
		return selectTarget{}, false
	}
	sql, ok := call.Args[1].(*ast.BasicLit)
	if !ok {
		return selectTarget{}, false
	}
	return selectTarget{Name: dest.Name, SQL: sql.Value}, true
}

// collectSelectTargets walks file in source order. It returns every matching
// db.Select call, keyed by the position of the `var` statement (DeclStmt)
// that most recently declared the call's destination variable before the call.
// Calls whose destination was not declared by an earlier `var` statement are
// skipped. That way an unrelated declaration is never picked up.
func collectSelectTargets(file *ast.File) map[token.Pos]selectTarget {
	// declPos maps a variable name to the position of the latest `var`
	// statement seen so far that declares it.
	declPos := make(map[string]token.Pos)
	targets := make(map[token.Pos]selectTarget)
	astutil.Apply(file, func(cr *astutil.Cursor) bool {
		switch v := cr.Node().(type) {
		case *ast.DeclStmt:
			gen, ok := v.Decl.(*ast.GenDecl)
			if !ok || gen.Tok != token.VAR {
				return true
			}
			for _, spec := range gen.Specs {
				vs, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for _, id := range vs.Names {
					declPos[id.Name] = v.Pos()
				}
			}
		case *ast.CallExpr:
			t, ok := matchSelectCall(v)
			if !ok {
				return true
			}
			if pos, declared := declPos[t.Name]; declared {
				targets[pos] = t
			}
		}
		return true
	}, nil)
	return targets
}

// rewriteSelectTargets replaces every DeclStmt whose position is a key of
// targets with `var <Name> []struct{}` and returns the resulting root node.
func rewriteSelectTargets(file *ast.File, targets map[token.Pos]selectTarget) ast.Node {
	return astutil.Apply(file, func(cr *astutil.Cursor) bool {
		decl, ok := cr.Node().(*ast.DeclStmt)
		if !ok {
			return true
		}
		t, ok := targets[decl.Pos()]
		if !ok {
			return true
		}
		// NOTE(review): the whole statement is replaced, so any other names
		// declared in the same `var` statement are dropped.
		cr.Replace(&ast.DeclStmt{
			Decl: &ast.GenDecl{
				Tok: token.VAR,
				Specs: []ast.Spec{&ast.ValueSpec{
					Names: []*ast.Ident{ast.NewIdent(t.Name)},
					Type: &ast.ArrayType{
						Elt: &ast.StructType{
							Fields: &ast.FieldList{List: []*ast.Field{}},
						},
					},
				}},
			},
		})
		return true
	}, nil)
}

// main parses a Go source file and finds each db.Select(&name, <literal>, ...)
// call. It rewrites the `var` statement that declared name into
// `var name []struct{}` and prints the resulting source to stdout.
//
// Usage: rewrite [file.go]   (defaults to memo/test.go)
func main() {
	filename := "memo/test.go"
	if len(os.Args) > 1 {
		filename = os.Args[1]
	}

	src, err := ioutil.ReadFile(filename)
	if err != nil {
		log.Fatalf("reading %s: %v", filename, err)
	}
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filename, src, parser.AllErrors)
	if err != nil {
		log.Fatalf("parsing %s: %v", filename, err)
	}

	n := rewriteSelectTargets(file, collectSelectTargets(file))

	// Print with the FileSet the file was parsed with. The printer resolves
	// node positions through it to recover line information. A fresh, empty
	// FileSet knows none of these positions, so the original line layout is
	// not preserved.
	if err := format.Node(os.Stdout, fset, n); err != nil {
		log.Fatalln("Error:", err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment