Skip to content

Instantly share code, notes, and snippets.

@Ehco1996
Created August 14, 2019 13:50
Show Gist options
  • Save Ehco1996/ab6caeac1a6bca1fa2138afebb9ff205 to your computer and use it in GitHub Desktop.
Save Ehco1996/ab6caeac1a6bca1fa2138afebb9ff205 to your computer and use it in GitHub Desktop.
过滤掉 SELECT 以外的语句,并为所有 SELECT 语句加上 `LIMIT 1`(已有 LIMIT 时则改写为 `LIMIT 1`)
package main
import (
"fmt"
"log"
"strings"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/format"
_ "github.com/pingcap/tidb/types/parser_driver"
)
// Rewrite bundles an SQL statement with its parsed AST and, once
// forceSelectLimit1 has run on a SELECT, the rewritten statement text.
type Rewrite struct {
// SQL is the statement text exactly as it was passed to NewRewrite.
SQL string
// NewSQL is the statement text restored from Stmt by forceSelectLimit1.
// NewRewrite does not set it, so it is empty until a SELECT has been rewritten.
NewSQL string
// Stmt is the AST node NewRewrite obtained by parsing SQL.
Stmt ast.StmtNode
}
// NewRewrite parses sql as a single statement using the given charset and
// collation and returns a Rewrite wrapping the original text and its AST.
//
// If sql cannot be parsed, the error is logged and the returned Rewrite
// carries the original text with a nil Stmt. The process is not terminated:
// one malformed statement should not take the whole program down, and
// forceSelectLimit1 already treats a nil Stmt as "nothing to rewrite".
func NewRewrite(sql, charset, collation string) *Rewrite {
	stmtNode, err := parser.New().ParseOneStmt(sql, charset, collation)
	if err != nil {
		log.Print("error...", err)
		// Leave Stmt unset so callers can detect the failure via Stmt == nil.
		return &Rewrite{SQL: sql}
	}
	return &Rewrite{
		SQL:  sql,
		Stmt: stmtNode,
	}
}
// newLimit builds a LIMIT clause node whose row count is val. Only Count is
// populated; every other field of the node keeps its zero value.
func newLimit(val int) *ast.Limit {
	return &ast.Limit{Count: ast.NewValueExpr(val)}
}
// checkLimitVisitor is the stateless AST visitor that forceSelectLimit1 passes
// to Accept; its Enter method forces a row count of 1 on the nodes it handles.
type checkLimitVisitor struct{}
// Enter is called for a node before its children are visited.
//
//   - *ast.Limit: the row count is overwritten with 1 and traversal continues
//     into the node's children.
//   - *ast.SelectStmt: the Limit field is replaced by a fresh node that holds
//     only a count of 1, and children are skipped.
//   - any other node: returned untouched, children skipped.
//
// NOTE(review): since children are skipped for a SelectStmt, the *ast.Limit
// branch is presumably never reached when traversal starts at a SELECT, and
// replacing the Limit node discards whatever the previous one carried (such
// as an offset) — confirm this is intended.
func (clv *checkLimitVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
	if limit, isLimit := in.(*ast.Limit); isLimit {
		limit.Count = ast.NewValueExpr(1)
		return limit, false
	}
	if sel, isSelect := in.(*ast.SelectStmt); isSelect {
		sel.Limit = newLimit(1)
	}
	return in, true
}
// Leave is called when traversal leaves a node. It hands the node back
// unchanged and always reports true.
func (*checkLimitVisitor) Leave(in ast.Node) (ast.Node, bool) {
	return in, true
}
// forceSelectLimit1 rewrites a SELECT statement so that it carries LIMIT 1
// and stores the restored statement text in rw.NewSQL.
//
// Statements that are not a SELECT (including a nil Stmt) are left alone and
// rw.NewSQL is not touched. The AST in rw.Stmt is modified in place.
//
// If the modified AST cannot be restored to text, the error is logged and
// rw.NewSQL is reset to the empty string so that callers never receive a
// partially written statement.
//
// It returns rw to allow call chaining.
func (rw *Rewrite) forceSelectLimit1() *Rewrite {
	// A type assertion on a nil interface yields ok == false, so this also
	// covers the case where parsing produced no statement.
	sel, ok := rw.Stmt.(*ast.SelectStmt)
	if !ok {
		return rw
	}
	sel.Accept(&checkLimitVisitor{})

	var sb strings.Builder
	ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
	if err := sel.Restore(ctx); err != nil {
		log.Print("restore sql: ", err)
		// sb may hold a truncated statement at this point; do not expose it.
		rw.NewSQL = ""
		return rw
	}
	rw.NewSQL = sb.String()
	return rw
}
// main runs the rewrite over three sample statements and prints the resulting
// NewSQL for each, one per line.
func main() {
	queries := []string{
		// OUT: SELECT `t1`.`a`,`t2`.`b` FROM `t1` JOIN `t2` ON `t1`.`id`=`t2`.`fid` WHERE `t1`.`c`>100 LIMIT 1
		"SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.fid WHERE t1.c>100 limit 100;",
		// OUT: SELECT `t1`.`a`,`t2`.`b` FROM `t1` JOIN `t2` ON `t1`.`id`=`t2`.`fid` WHERE `t1`.`c`>100 LIMIT 1
		"SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.fid WHERE t1.c>100;",
		// OUT: "" (not a SELECT, so NewSQL stays empty)
		"DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste';",
	}
	for _, query := range queries {
		rewritten := NewRewrite(query, "", "")
		rewritten.forceSelectLimit1()
		fmt.Println(rewritten.NewSQL)
	}
}
@ld140319
Copy link

如果需要将所有查询参数替换为? 需要怎么处理呢

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment