public
Created

Reads pt-query-digest output from stdin and replays (multiplexes) the queries against MySQL across multiple concurrent workers.

  • Download Gist
mysql_query_multiplexer.go
Go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
package main
 
import (
"bufio"
"bytes"
_ "github.com/Go-SQL-Driver/MySQL"
"database/sql"
"flag"
"fmt"
"io"
"log"
"os"
"strings"
)
 
var (
	// msgs carries each query parsed by produce to the consume workers.
	// It holds up to 1,000,000 pending queries; produce drops a query
	// rather than block when the buffer is full.
	msgs = make(chan string, 1000000)

	// done receives one value from each consume worker as it finishes.
	done = make(chan bool)
)

// Connection settings for the target MySQL server, filled in from
// command-line flags by main.
var (
	db_host     string
	db_user     string
	db_password string
	db_name     string
	db_charset  string
)
 
// produce reads pt-query-digest output from stdin, splits it into individual
// queries and queues each one on msgs for the consume workers. It closes msgs
// when input ends (or a read error occurs) so the workers know to stop.
func produce() {
	r := bufio.NewReaderSize(os.Stdin, 32*1024*1024)
	if err := splitQueries(r, msgs); err != nil {
		fmt.Fprintln(os.Stderr, "Error reading from stdin: "+err.Error())
	}
	fmt.Println("Done with produce")
	close(msgs)
}

// splitQueries reads pt-query-digest output from r and sends each query on out.
//
// Based on pt-query-digest's output, a line starting with '#' begins a new
// comment block and therefore ends the previous query. This is slightly more
// robust than looking for a trailing semicolon, in case a query doesn't end in
// one. Each query line is collected with a leading space, and the result is
// whitespace-trimmed. Results of one character or less are ignored.
//
// The query still pending when input ends is sent as well, whether or not the
// last line has a trailing newline. A send that would block is dropped and
// logged instead of stalling the reader.
//
// It returns the first read error other than io.EOF. In that case the partially
// collected query is discarded, because it may be truncated.
func splitQueries(r io.Reader, out chan<- string) error {
	// bufio.NewReader returns r unchanged when it is already a *bufio.Reader
	// with at least the default buffer size (e.g. the one built by produce).
	br := bufio.NewReader(r)
	var query bytes.Buffer

	// flush queues the collected query (if non-trivial) and starts a new one.
	flush := func() {
		s := strings.TrimSpace(query.String())
		query.Reset()
		if len(s) <= 1 {
			return
		}
		select {
		case out <- s:
		default:
			log.Println("Channel full; dropping message")
		}
	}

	for {
		line, err := br.ReadString('\n')
		if err != nil && err != io.EOF {
			return err
		}
		// At EOF, line holds any final text that lacked a trailing newline.
		switch {
		case len(line) == 0:
		case line[0] == '#':
			flush()
		default:
			query.WriteByte(' ')
			query.WriteString(line)
		}
		if err == io.EOF {
			// The last query is not followed by a '#' line; emit it here.
			flush()
			return nil
		}
	}
}
 
func consume() {
defer func() {
done <- true
}()
 
for {
select {
case msg, ok := <-msgs:
if !ok {
return
}
 
// Creating a new connection for every query. Trying to exercise MySQL's connection
// handling a bit
db, e := sql.Open("mysql", db_user+":"+db_password+"@tcp("+db_host+":3306)/"+db_name+"?charset="+db_charset)
if e != nil {
panic(e)
}
 
_, err := db.Exec(msg)
if err != nil {
log.Println(msg)
log.Println(err.Error())
}
db.Close()
}
}
 
}
 
func main() {
var threads = flag.Int("threads", 8, "Execution threads")
flag.StringVar(&db_host, "db-host", "", "DB Host")
flag.StringVar(&db_user, "db-user", "", "DB Username")
flag.StringVar(&db_password, "db-password", "", "DB Password")
flag.StringVar(&db_name, "db-name", "", "DB Name")
flag.StringVar(&db_charset, "db-charset", "", "DB Character set")
var error_log = flag.String("log", "", "Error log file location and name")
var help = flag.Bool("h", false, "Help")
flag.Parse()
 
if *help == true {
flag.Usage()
return
}
 
if db_host == "" || db_user == "" || db_password == "" || db_name == "" || db_charset == "" || *error_log == "" {
flag.Usage()
return
}
 
logFile, err := os.OpenFile(*error_log, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
panic(err.Error() + " opening log file: " + *error_log)
}
log.SetOutput(logFile)
log.SetFlags(log.LstdFlags)
 
fmt.Println(fmt.Sprintf("Starting go-execution with %d threads", *threads))
go produce()
for i := 0; i < *threads; i++ {
go consume()
}
 
for i := 0; i < *threads; i++ {
<-done
}
 
fmt.Println("Done with consume")
}

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.