Skip to content

Instantly share code, notes, and snippets.

@fatih

fatih/copyast.go

Created Jul 13, 2019
Embed
What would you like to do?
Fork of wire/internal/copyast with support for *ast.File
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package copyast
import (
"fmt"
"go/ast"
"golang.org/x/tools/go/ast/astutil"
)
// File performs a deep copy of an *ast.File
func File(file *ast.File) *ast.File {
nf := *file
if file.Doc != nil {
nf.Doc = Node(file.Doc).(*ast.CommentGroup)
}
if nf.Name != nil {
nf.Name = Node(file.Name).(*ast.Ident)
}
nf.Package = file.Package
nf.Decls = copyDeclList(file.Decls)
nf.Imports = copyImportSpecs(file.Imports)
nf.Unresolved = file.Unresolved // keep untouched
nf.Comments = copyCommentGroups(file.Comments)
return &nf
}
// copyDeclList deep-copies each declaration in x with Node, returning a new
// slice of the same length.
//
// A nil x yields nil. An empty non-nil x is returned with its capacity
// clipped to zero, so an append on the copy reallocates instead of writing
// into x's spare capacity.
func copyDeclList(x []ast.Decl) []ast.Decl {
	if len(x) == 0 {
		return x[:0:0]
	}
	decls := make([]ast.Decl, len(x))
	for i, decl := range x {
		decls[i] = Node(decl).(ast.Decl)
	}
	return decls
}
// copyImportSpecs deep-copies each import spec in x with Node, returning a
// new slice of the same length.
//
// A nil x yields nil. An empty non-nil x is returned with its capacity
// clipped to zero, so an append on the copy reallocates instead of writing
// into x's spare capacity.
func copyImportSpecs(x []*ast.ImportSpec) []*ast.ImportSpec {
	if len(x) == 0 {
		return x[:0:0]
	}
	imports := make([]*ast.ImportSpec, len(x))
	for i, imp := range x {
		imports[i] = Node(imp).(*ast.ImportSpec)
	}
	return imports
}
// copyCommentGroups deep-copies each comment group in x with Node, returning
// a new slice of the same length.
//
// A nil x yields nil. An empty non-nil x is returned with its capacity
// clipped to zero, so an append on the copy reallocates instead of writing
// into x's spare capacity.
func copyCommentGroups(x []*ast.CommentGroup) []*ast.CommentGroup {
	if len(x) == 0 {
		return x[:0:0]
	}
	comments := make([]*ast.CommentGroup, len(x))
	for i, comment := range x {
		comments[i] = Node(comment).(*ast.CommentGroup)
	}
	return comments
}
// Node performs a deep copy of the AST rooted at original and returns the
// copy (nil when original is nil). *ast.Ident nodes are not copied: the
// result shares them with the original so they can still be used with the
// original *types.Info.
//
// This allows using astutil.Apply to rewrite an AST without modifying
// the original AST.
//
// Node panics on node types it does not handle, including *ast.File (use
// File for whole files). NOTE(review): this predates generics. On Go 1.18+
// ASTs, FuncType.TypeParams and TypeSpec.TypeParams are not copied and
// *ast.IndexListExpr is unhandled (panics). Extend before relying on it for
// generic code.
func Node(original ast.Node) ast.Node {
	// This function is necessarily long. No utility function exists to do this
	// clone, as most any attempt would need to have customization options, which
	// would need to be as expressive as Apply. A possibility to shorten the code
	// here would be to use reflection, but that trades clarity for shorter code.
	//
	// Only the post-order callback is used: by the time a node is visited, all
	// of its children already have copies in m, and the *FromMap helpers look
	// them up.
	m := make(map[ast.Node]ast.Node)
	astutil.Apply(original, nil, func(c *astutil.Cursor) bool {
		switch node := c.Node().(type) {
		case nil:
			// No-op.
		case *ast.ArrayType:
			m[node] = &ast.ArrayType{
				Lbrack: node.Lbrack,
				Len:    exprFromMap(m, node.Len),
				Elt:    exprFromMap(m, node.Elt),
			}
		case *ast.AssignStmt:
			m[node] = &ast.AssignStmt{
				Lhs:    copyExprList(m, node.Lhs),
				TokPos: node.TokPos,
				Tok:    node.Tok,
				Rhs:    copyExprList(m, node.Rhs),
			}
		case *ast.BadDecl:
			m[node] = &ast.BadDecl{
				From: node.From,
				To:   node.To,
			}
		case *ast.BadExpr:
			m[node] = &ast.BadExpr{
				From: node.From,
				To:   node.To,
			}
		case *ast.BadStmt:
			m[node] = &ast.BadStmt{
				From: node.From,
				To:   node.To,
			}
		case *ast.BasicLit:
			m[node] = &ast.BasicLit{
				ValuePos: node.ValuePos,
				Kind:     node.Kind,
				Value:    node.Value,
			}
		case *ast.BinaryExpr:
			m[node] = &ast.BinaryExpr{
				X:     exprFromMap(m, node.X),
				OpPos: node.OpPos,
				Op:    node.Op,
				Y:     exprFromMap(m, node.Y),
			}
		case *ast.BlockStmt:
			m[node] = &ast.BlockStmt{
				Lbrace: node.Lbrace,
				List:   copyStmtList(m, node.List),
				Rbrace: node.Rbrace,
			}
		case *ast.BranchStmt:
			m[node] = &ast.BranchStmt{
				TokPos: node.TokPos,
				Tok:    node.Tok,
				Label:  identFromMap(m, node.Label),
			}
		case *ast.CallExpr:
			m[node] = &ast.CallExpr{
				Fun:      exprFromMap(m, node.Fun),
				Lparen:   node.Lparen,
				Args:     copyExprList(m, node.Args),
				Ellipsis: node.Ellipsis,
				Rparen:   node.Rparen,
			}
		case *ast.CaseClause:
			m[node] = &ast.CaseClause{
				Case:  node.Case,
				List:  copyExprList(m, node.List),
				Colon: node.Colon,
				Body:  copyStmtList(m, node.Body),
			}
		case *ast.ChanType:
			m[node] = &ast.ChanType{
				Begin: node.Begin,
				Arrow: node.Arrow,
				Dir:   node.Dir,
				Value: exprFromMap(m, node.Value),
			}
		case *ast.CommClause:
			m[node] = &ast.CommClause{
				Case:  node.Case,
				Comm:  stmtFromMap(m, node.Comm),
				Colon: node.Colon,
				Body:  copyStmtList(m, node.Body),
			}
		case *ast.Comment:
			m[node] = &ast.Comment{
				Slash: node.Slash,
				Text:  node.Text,
			}
		case *ast.CommentGroup:
			// Always record a copy, even for a group with no comments, so that
			// commentGroupFromMap (and Node's own return) never hit a missing
			// entry.
			cg := new(ast.CommentGroup)
			if node.List != nil {
				cg.List = make([]*ast.Comment, len(node.List))
				for i, comment := range node.List {
					// Each *ast.Comment is a child of the group, so its copy
					// was already recorded by the *ast.Comment case above.
					cg.List[i] = m[comment].(*ast.Comment)
				}
			}
			m[node] = cg
		case *ast.CompositeLit:
			m[node] = &ast.CompositeLit{
				Type:       exprFromMap(m, node.Type),
				Lbrace:     node.Lbrace,
				Elts:       copyExprList(m, node.Elts),
				Rbrace:     node.Rbrace,
				Incomplete: node.Incomplete,
			}
		case *ast.DeclStmt:
			m[node] = &ast.DeclStmt{
				Decl: m[node.Decl].(ast.Decl),
			}
		case *ast.DeferStmt:
			m[node] = &ast.DeferStmt{
				Defer: node.Defer,
				Call:  callExprFromMap(m, node.Call),
			}
		case *ast.Ellipsis:
			m[node] = &ast.Ellipsis{
				Ellipsis: node.Ellipsis,
				Elt:      exprFromMap(m, node.Elt),
			}
		case *ast.EmptyStmt:
			m[node] = &ast.EmptyStmt{
				Semicolon: node.Semicolon,
				Implicit:  node.Implicit,
			}
		case *ast.ExprStmt:
			m[node] = &ast.ExprStmt{
				X: exprFromMap(m, node.X),
			}
		case *ast.Field:
			m[node] = &ast.Field{
				Doc:     commentGroupFromMap(m, node.Doc),
				Names:   copyIdentList(m, node.Names),
				Type:    exprFromMap(m, node.Type),
				Tag:     basicLitFromMap(m, node.Tag),
				Comment: commentGroupFromMap(m, node.Comment),
			}
		case *ast.FieldList:
			fl := &ast.FieldList{
				Opening: node.Opening,
				Closing: node.Closing,
			}
			if node.List != nil {
				fl.List = make([]*ast.Field, len(node.List))
				for i := range node.List {
					fl.List[i] = m[node.List[i]].(*ast.Field)
				}
			}
			m[node] = fl
		case *ast.ForStmt:
			m[node] = &ast.ForStmt{
				For:  node.For,
				Init: stmtFromMap(m, node.Init),
				Cond: exprFromMap(m, node.Cond),
				Post: stmtFromMap(m, node.Post),
				Body: blockStmtFromMap(m, node.Body),
			}
		case *ast.FuncDecl:
			m[node] = &ast.FuncDecl{
				Doc:  commentGroupFromMap(m, node.Doc),
				Recv: fieldListFromMap(m, node.Recv),
				Name: identFromMap(m, node.Name),
				Type: funcTypeFromMap(m, node.Type),
				Body: blockStmtFromMap(m, node.Body),
			}
		case *ast.FuncLit:
			m[node] = &ast.FuncLit{
				Type: funcTypeFromMap(m, node.Type),
				Body: blockStmtFromMap(m, node.Body),
			}
		case *ast.FuncType:
			m[node] = &ast.FuncType{
				Func:    node.Func,
				Params:  fieldListFromMap(m, node.Params),
				Results: fieldListFromMap(m, node.Results),
			}
		case *ast.GenDecl:
			decl := &ast.GenDecl{
				Doc:    commentGroupFromMap(m, node.Doc),
				TokPos: node.TokPos,
				Tok:    node.Tok,
				Lparen: node.Lparen,
				Rparen: node.Rparen,
			}
			if node.Specs != nil {
				decl.Specs = make([]ast.Spec, len(node.Specs))
				for i := range node.Specs {
					decl.Specs[i] = m[node.Specs[i]].(ast.Spec)
				}
			}
			m[node] = decl
		case *ast.GoStmt:
			m[node] = &ast.GoStmt{
				Go:   node.Go,
				Call: callExprFromMap(m, node.Call),
			}
		case *ast.Ident:
			// Keep identifiers the same identity so they can be conveniently
			// used with the original *types.Info.
			m[node] = node
		case *ast.IfStmt:
			m[node] = &ast.IfStmt{
				If:   node.If,
				Init: stmtFromMap(m, node.Init),
				Cond: exprFromMap(m, node.Cond),
				Body: blockStmtFromMap(m, node.Body),
				Else: stmtFromMap(m, node.Else),
			}
		case *ast.ImportSpec:
			m[node] = &ast.ImportSpec{
				Doc:     commentGroupFromMap(m, node.Doc),
				Name:    identFromMap(m, node.Name),
				Path:    basicLitFromMap(m, node.Path),
				Comment: commentGroupFromMap(m, node.Comment),
				EndPos:  node.EndPos,
			}
		case *ast.IncDecStmt:
			m[node] = &ast.IncDecStmt{
				X:      exprFromMap(m, node.X),
				TokPos: node.TokPos,
				Tok:    node.Tok,
			}
		case *ast.IndexExpr:
			m[node] = &ast.IndexExpr{
				X:      exprFromMap(m, node.X),
				Lbrack: node.Lbrack,
				Index:  exprFromMap(m, node.Index),
				Rbrack: node.Rbrack,
			}
		case *ast.InterfaceType:
			m[node] = &ast.InterfaceType{
				Interface:  node.Interface,
				Methods:    fieldListFromMap(m, node.Methods),
				Incomplete: node.Incomplete,
			}
		case *ast.KeyValueExpr:
			m[node] = &ast.KeyValueExpr{
				Key:   exprFromMap(m, node.Key),
				Colon: node.Colon,
				Value: exprFromMap(m, node.Value),
			}
		case *ast.LabeledStmt:
			m[node] = &ast.LabeledStmt{
				Label: identFromMap(m, node.Label),
				Colon: node.Colon,
				Stmt:  stmtFromMap(m, node.Stmt),
			}
		case *ast.MapType:
			m[node] = &ast.MapType{
				Map:   node.Map,
				Key:   exprFromMap(m, node.Key),
				Value: exprFromMap(m, node.Value),
			}
		case *ast.ParenExpr:
			m[node] = &ast.ParenExpr{
				Lparen: node.Lparen,
				X:      exprFromMap(m, node.X),
				Rparen: node.Rparen,
			}
		case *ast.RangeStmt:
			m[node] = &ast.RangeStmt{
				For:    node.For,
				Key:    exprFromMap(m, node.Key),
				Value:  exprFromMap(m, node.Value),
				TokPos: node.TokPos,
				Tok:    node.Tok,
				X:      exprFromMap(m, node.X),
				Body:   blockStmtFromMap(m, node.Body),
			}
		case *ast.ReturnStmt:
			m[node] = &ast.ReturnStmt{
				Return:  node.Return,
				Results: copyExprList(m, node.Results),
			}
		case *ast.SelectStmt:
			m[node] = &ast.SelectStmt{
				Select: node.Select,
				Body:   blockStmtFromMap(m, node.Body),
			}
		case *ast.SelectorExpr:
			m[node] = &ast.SelectorExpr{
				X:   exprFromMap(m, node.X),
				Sel: identFromMap(m, node.Sel),
			}
		case *ast.SendStmt:
			m[node] = &ast.SendStmt{
				Chan:  exprFromMap(m, node.Chan),
				Arrow: node.Arrow,
				Value: exprFromMap(m, node.Value),
			}
		case *ast.SliceExpr:
			m[node] = &ast.SliceExpr{
				X:      exprFromMap(m, node.X),
				Lbrack: node.Lbrack,
				Low:    exprFromMap(m, node.Low),
				High:   exprFromMap(m, node.High),
				Max:    exprFromMap(m, node.Max),
				Slice3: node.Slice3,
				Rbrack: node.Rbrack,
			}
		case *ast.StarExpr:
			m[node] = &ast.StarExpr{
				Star: node.Star,
				X:    exprFromMap(m, node.X),
			}
		case *ast.StructType:
			m[node] = &ast.StructType{
				Struct:     node.Struct,
				Fields:     fieldListFromMap(m, node.Fields),
				Incomplete: node.Incomplete,
			}
		case *ast.SwitchStmt:
			m[node] = &ast.SwitchStmt{
				Switch: node.Switch,
				Init:   stmtFromMap(m, node.Init),
				Tag:    exprFromMap(m, node.Tag),
				Body:   blockStmtFromMap(m, node.Body),
			}
		case *ast.TypeAssertExpr:
			m[node] = &ast.TypeAssertExpr{
				X:      exprFromMap(m, node.X),
				Lparen: node.Lparen,
				Type:   exprFromMap(m, node.Type),
				Rparen: node.Rparen,
			}
		case *ast.TypeSpec:
			m[node] = &ast.TypeSpec{
				Doc:     commentGroupFromMap(m, node.Doc),
				Name:    identFromMap(m, node.Name),
				Assign:  node.Assign,
				Type:    exprFromMap(m, node.Type),
				Comment: commentGroupFromMap(m, node.Comment),
			}
		case *ast.TypeSwitchStmt:
			m[node] = &ast.TypeSwitchStmt{
				Switch: node.Switch,
				Init:   stmtFromMap(m, node.Init),
				Assign: stmtFromMap(m, node.Assign),
				Body:   blockStmtFromMap(m, node.Body),
			}
		case *ast.UnaryExpr:
			m[node] = &ast.UnaryExpr{
				OpPos: node.OpPos,
				Op:    node.Op,
				X:     exprFromMap(m, node.X),
			}
		case *ast.ValueSpec:
			m[node] = &ast.ValueSpec{
				Doc:     commentGroupFromMap(m, node.Doc),
				Names:   copyIdentList(m, node.Names),
				Type:    exprFromMap(m, node.Type),
				Values:  copyExprList(m, node.Values),
				Comment: commentGroupFromMap(m, node.Comment),
			}
		default:
			panic(fmt.Sprintf("unhandled AST node: %T", node))
		}
		return true
	})
	return m[original]
}
// commentGroupFromMap returns the copy recorded in m for key, or nil when key
// is nil. A non-nil key without a *ast.CommentGroup entry causes a panic.
func commentGroupFromMap(m map[ast.Node]ast.Node, key *ast.CommentGroup) *ast.CommentGroup {
	if key != nil {
		return m[key].(*ast.CommentGroup)
	}
	return nil
}
// exprFromMap returns the expression copy recorded in m for key. A nil key
// maps to nil. A non-nil key without an ast.Expr entry causes a panic.
func exprFromMap(m map[ast.Node]ast.Node, key ast.Expr) ast.Expr {
	var copied ast.Expr
	if key != nil {
		copied = m[key].(ast.Expr)
	}
	return copied
}
// stmtFromMap returns the statement copy recorded in m for key. A nil key
// maps to nil. A non-nil key without an ast.Stmt entry causes a panic.
func stmtFromMap(m map[ast.Node]ast.Node, key ast.Stmt) ast.Stmt {
	if key != nil {
		return m[key].(ast.Stmt)
	}
	return nil
}
// identFromMap returns the identifier recorded in m for key, or nil when key
// is nil. A non-nil key without a *ast.Ident entry causes a panic.
func identFromMap(m map[ast.Node]ast.Node, key *ast.Ident) *ast.Ident {
	var ident *ast.Ident
	if key != nil {
		ident = m[key].(*ast.Ident)
	}
	return ident
}
// blockStmtFromMap returns the block copy recorded in m for key, or nil when
// key is nil. A non-nil key without a *ast.BlockStmt entry causes a panic.
func blockStmtFromMap(m map[ast.Node]ast.Node, key *ast.BlockStmt) *ast.BlockStmt {
	if key != nil {
		return m[key].(*ast.BlockStmt)
	}
	return nil
}
// fieldListFromMap returns the field-list copy recorded in m for key, or nil
// when key is nil. A non-nil key without a *ast.FieldList entry causes a
// panic.
func fieldListFromMap(m map[ast.Node]ast.Node, key *ast.FieldList) *ast.FieldList {
	var fields *ast.FieldList
	if key != nil {
		fields = m[key].(*ast.FieldList)
	}
	return fields
}
// callExprFromMap returns the call copy recorded in m for key, or nil when
// key is nil. A non-nil key without a *ast.CallExpr entry causes a panic.
func callExprFromMap(m map[ast.Node]ast.Node, key *ast.CallExpr) *ast.CallExpr {
	if key != nil {
		return m[key].(*ast.CallExpr)
	}
	return nil
}
// basicLitFromMap returns the literal copy recorded in m for key, or nil when
// key is nil. A non-nil key without a *ast.BasicLit entry causes a panic.
func basicLitFromMap(m map[ast.Node]ast.Node, key *ast.BasicLit) *ast.BasicLit {
	var lit *ast.BasicLit
	if key != nil {
		lit = m[key].(*ast.BasicLit)
	}
	return lit
}
// funcTypeFromMap returns the function-type copy recorded in m for key, or
// nil when key is nil. A non-nil key without a *ast.FuncType entry causes a
// panic.
func funcTypeFromMap(m map[ast.Node]ast.Node, key *ast.FuncType) *ast.FuncType {
	if key != nil {
		return m[key].(*ast.FuncType)
	}
	return nil
}
// copyExprList replaces every expression in exprs with its copy from m.
// A nil slice stays nil; any other slice (even an empty one) yields a new
// slice of the same length. Each element must already have an ast.Expr
// entry in m, otherwise the call panics.
func copyExprList(m map[ast.Node]ast.Node, exprs []ast.Expr) []ast.Expr {
	if exprs == nil {
		return nil
	}
	out := make([]ast.Expr, 0, len(exprs))
	for _, e := range exprs {
		out = append(out, m[e].(ast.Expr))
	}
	return out
}
// copyStmtList replaces every statement in stmts with its copy from m.
// A nil slice stays nil; any other slice (even an empty one) yields a new
// slice of the same length. Each element must already have an ast.Stmt
// entry in m, otherwise the call panics.
func copyStmtList(m map[ast.Node]ast.Node, stmts []ast.Stmt) []ast.Stmt {
	if stmts == nil {
		return nil
	}
	copied := make([]ast.Stmt, len(stmts))
	for i, s := range stmts {
		copied[i] = m[s].(ast.Stmt)
	}
	return copied
}
// copyIdentList replaces every identifier in idents with its entry from m.
// A nil slice stays nil; any other slice (even an empty one) yields a new
// slice of the same length. Each element must already have a *ast.Ident
// entry in m, otherwise the call panics.
func copyIdentList(m map[ast.Node]ast.Node, idents []*ast.Ident) []*ast.Ident {
	if idents == nil {
		return nil
	}
	result := make([]*ast.Ident, 0, len(idents))
	for _, id := range idents {
		result = append(result, m[id].(*ast.Ident))
	}
	return result
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment