Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
// This is a reduced test case for https://github.com/go-mgo/mgo/issues/254
//
// It replaces mgo's DialServer() hook with one that returns a custom net.Conn.
// Some of these custom connections time out _once_, after the first 1000 bytes are read.
// It seems like DialServer() is called more times than the PoolLimit, which is
// set to 5, and not all these sockets are used, so for convenience, we pick
// every third connection to automatically time out once.
// It is still possible for this test to be slightly flaky, you may have to run it more than once.
//
// Remember to change the Dial URL and the database name used in the Insert call to something valid.
//
// Running this test with `go run ./main.go` will cause ~11 sockets to be
// created. Some of these will time out, at which point, since the server's
// `abended` is not reset, every subsequent query, even through sockets that
// are _NOT_ failed, will keep performing the authentication cycle. This will
// lead to output like:
// Having to master check 0xc420014a20
// Having to master check 0xc4200da060
// Having to master check 0xc4200da4c0
// Having to master check 0xc420014bc0
// ...
//
// This output is from the Write() snooping we are doing and means that the
// client is verifying master on every socket, as described in #254.
package main
import (
"bytes"
"fmt"
"net"
"os"
"sync"
"time"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
// myConn wraps a net.Conn so the repro can (a) stall a connection once after
// it has read more than 1000 bytes and (b) print every outgoing buffer that
// mentions "ismaster".
//
// main builds this struct with a composite literal, so keep the field order
// and types stable.
type myConn struct {
	// Conn is the real connection; all I/O is delegated to it.
	net.Conn
	// bytesRead is the running total of bytes returned by Read.
	bytesRead uint64
	// tripped is true once the one-off stall has happened. If it starts out
	// true, the connection never stalls.
	tripped bool
}

// Read reads from the wrapped connection and counts the bytes returned.
//
// On an untripped connection that has not read anything yet, it first prints
// a notice. When the running total first exceeds 1000 bytes on an untripped
// connection, Read prints another notice, sleeps for 5 seconds, and then
// returns the data it already read. It then marks the connection tripped, so
// the stall happens at most once per connection. The sleep is intended to
// outlast the client's timeout (see the file header).
func (c *myConn) Read(b []byte) (int, error) {
	if !c.tripped && c.bytesRead == 0 {
		fmt.Printf("Going to eventually trip %p\n", c)
	}

	n, err := c.Conn.Read(b)
	c.bytesRead += uint64(n)

	if c.tripped || c.bytesRead <= 1000 {
		return n, err
	}
	fmt.Printf("Socket not tripped yet, timing out %p\n", c)
	time.Sleep(5 * time.Second)
	c.tripped = true
	return n, err
}

// Write forwards b to the wrapped connection unchanged. Before forwarding, it
// prints a line whenever b contains the bytes "ismaster", which makes repeated
// master checks visible on stdout.
func (c *myConn) Write(b []byte) (int, error) {
	const probe = "ismaster"
	if bytes.Contains(b, []byte(probe)) {
		fmt.Printf("Having to master check %p\n", c)
	}
	return c.Conn.Write(b)
}
// myLogger is a minimal logger intended for mgo.SetLogger (see the
// commented-out call in main). It writes each message to standard error.
type myLogger struct{}

// Output writes s, followed by a newline, to os.Stderr.
//
// calldepth is accepted so the method matches the logger signature mgo
// expects (presumably mirroring (*log.Logger).Output), but it is not used.
// Any error from the write is returned to the caller rather than being
// silently discarded.
func (m *myLogger) Output(calldepth int, s string) error {
	_, err := fmt.Fprintln(os.Stderr, s)
	return err
}
// main dials MongoDB through a custom DialServer that wraps each connection in
// a myConn, leaving every third one armed to stall once. It then runs
// concurrent inserts so that every socket in the pool is used. Repeated
// "Having to master check" lines on stdout show the behaviour described in
// the file header.
//
// Setup failures print the error and exit with status 1.
func main() {
	//mgo.SetLogger(&myLogger{})
	//mgo.SetDebug(true)
	info, err := mgo.ParseURL("mongodb://localhost:32768/mgo-trawl")
	if err != nil {
		fmt.Println("Error parsing url", err)
		os.Exit(1)
	}
	info.Timeout = 2 * time.Second
	info.PoolLimit = 5

	var (
		dialMutex sync.Mutex
		dialCalls int
	)
	info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
		// DialServer may be called concurrently, so number the dials under a lock.
		dialMutex.Lock()
		dialCalls++
		d := dialCalls
		dialMutex.Unlock()

		conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
		if err != nil {
			return nil, err
		}
		// tripped == true disables the stall. Only every third connection
		// (d divisible by 3) is left untripped, so only a few of them trip.
		c := &myConn{Conn: conn, tripped: d%3 != 0}
		fmt.Printf("Established wrapped conn %p\n", c)
		return c, nil
	}

	session, err := mgo.DialWithInfo(info)
	if err != nil {
		fmt.Println("Error establishing session", err)
		os.Exit(1)
	}
	defer session.Close()

	// 12 goroutines x 10001 inserts each, to ensure all sockets in the pool
	// get used.
	const goroutines = 12
	var wg sync.WaitGroup
	for g := 0; g < goroutines; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for k := 10000; k >= 0; k-- {
				s := session.Copy()
				// err must be local to this goroutine: assigning the outer err
				// from several goroutines at once is a data race.
				if err := s.DB("mgo-trawl").C("test").Insert(bson.M{"hi": "there"}); err != nil {
					fmt.Println("Insert error!", k, err)
				}
				s.Close()
			}
		}()
	}
	wg.Wait()
	fmt.Println("All inserts done")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment