Skip to content

Instantly share code, notes, and snippets.

@yanniszark
Created February 9, 2019 15:50
Show Gist options
  • Save yanniszark/0968e0ace0efc2f0307abefd69c8edb4 to your computer and use it in GitHub Desktop.
Save yanniszark/0968e0ace0efc2f0307abefd69c8edb4 to your computer and use it in GitHub Desktop.
netrl_test.go
package netrl_test
import (
"context"
"errors"
"flag"
"io"
"log"
"net"
"sync"
"testing"
"time"
"github.com/yanniszark/scylla-code-assignment/netrl"
)
// Byte-size units used to express bandwidth limits.
const (
	kb = 1 << 10  // 1024 bytes
	mb = kb << 10 // 1024 * 1024 bytes
)
// Command-line knobs for TestListener. Bandwidth limits are in bytes per
// second, as stated in their usage strings.
var (
	// typ selects the listener under test: "net" or "yannis".
	typ = flag.String("type", "net", "Listener type.")
	// count is the number of concurrent reading clients.
	count = flag.Int("clients", 5, "Number of clients")
	// limit is the aggregate limit; the default is five connections' worth
	// of the default per-connection limit.
	limit = flag.Int("limit", 5*256*kb, "Limit bandwidth (bytes per second)")
	// limitConn is the limit applied to each accepted connection.
	limitConn = flag.Int("limit-conn", 256*kb, "Limit per connection bandwidth (bytes per second)")
	// duration is how long each client keeps reading.
	duration = flag.Duration("time", 30*time.Second, "Test duration")
	// epsilon is the relative tolerance: a client's total must fall within
	// want ± want*epsilon.
	epsilon = flag.Float64("epsilon", 0.05, "How close")
)
// byteReader is an io.Reader that never runs dry: every Read fills the whole
// buffer with the receiver's byte value and never reports an error.
type byteReader byte

// Read sets every element of p to byte(b) and returns len(p), nil.
func (b byteReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	p[0] = byte(b)
	// Double the filled prefix on each pass; copy truncates at the end of p,
	// so any length is handled.
	for filled := 1; filled < len(p); filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return len(p), nil
}
// timedWriter is an io.Writer that discards its input while recording when
// the first and the most recent Write happened. Once the span between them
// reaches the -time flag value, Write starts returning an error so that a
// surrounding io.Copy stops.
type timedWriter struct {
	first time.Time // time of the first Write; zero until then
	last  time.Time // time of the most recent Write
}

// Write always reports all of p as written and stamps the call time. It
// returns a non-nil "stop" error once duration() has reached *duration.
func (tw *timedWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	if tw.first.IsZero() {
		tw.first = time.Now()
	}
	tw.last = time.Now()
	if elapsed := tw.duration(); elapsed >= *duration {
		err = errors.New("stop")
	}
	return n, err
}

// duration reports the time elapsed between the first and the latest Write.
func (tw *timedWriter) duration() time.Duration {
	elapsed := tw.last.Sub(tw.first)
	return elapsed
}
// Listener adapts an existing net.Listener by routing Accept through
// AcceptFunc. Every other method (Close, Addr) is promoted from the embedded
// Listener.
type Listener struct {
	net.Listener
	AcceptFunc func() (net.Conn, error)
}

// Accept invokes AcceptFunc and hands back its results untouched.
func (ln Listener) Accept() (net.Conn, error) {
	accept := ln.AcceptFunc
	return accept()
}
// yannisListener opens a netrl listener on network/addr and sets its global
// limit to limit. Each accepted connection then gets its own limit of
// connLimit via netrl.SetLimitForConn. The units follow the -limit /
// -limit-conn flags (bytes per second).
// ctx is unused; it is accepted only so the signature matches the other
// listener constructors used by TestListener.
func yannisListener(ctx context.Context, network, addr string, limit, connLimit int) (net.Listener, error) {
	l, err := netrl.Listen(network, addr)
	if err != nil {
		return nil, err
	}
	l.SetGlobalLimit(float64(limit))
	return Listener{
		Listener: l,
		AcceptFunc: func() (net.Conn, error) {
			conn, err := l.Accept()
			if err != nil {
				return nil, err
			}
			if err := netrl.SetLimitForConn(conn, float64(connLimit)); err != nil {
				// The caller never sees conn on this path, so close it here
				// rather than leaking the socket; the limit error takes
				// precedence over any Close error.
				_ = conn.Close()
				return nil, err
			}
			return conn, nil
		},
	}, nil
}
// netListener is the unthrottled baseline. It wraps a plain net.Listen
// listener in Listener. ctx, limit and connLimit are ignored; they exist only
// so the signature matches yannisListener.
func netListener(ctx context.Context, network, addr string, limit, connLimit int) (net.Listener, error) {
	inner, err := net.Listen(network, addr)
	if err != nil {
		return nil, err
	}
	var wrapped Listener
	wrapped.Listener = inner
	wrapped.AcceptFunc = inner.Accept
	return wrapped, nil
}
func TestListener(t *testing.T) {
flag.Parse()
var listen func(context.Context, string, string, int, int) (net.Listener, error)
switch *typ {
case "yannis":
listen = yannisListener
case "net":
listen = netListener
default:
t.Fatalf("unsupported listener type: %s", *typ)
}
l, err := listen(context.Background(), "tcp", "127.0.0.1:0", *limit, *limitConn)
if err != nil {
t.Fatalf("listen()=%s", err)
}
defer l.Close()
go serve(l)
time.Sleep(2 * time.Second)
d := &net.Dialer{
Timeout: 2 * time.Second,
KeepAlive: 5 * time.Second,
}
w := make([]timedWriter, *count)
n := make([]int64, *count)
var wg sync.WaitGroup
for i := range w {
wg.Add(1)
go func(i int) {
defer wg.Done()
c, err := d.Dial(l.Addr().Network(), l.Addr().String())
if err != nil {
t.Fatalf("Dial()=%s", err)
}
defer c.Close()
n[i], _ = io.Copy(&w[i], c)
}(i)
}
want := int(duration.Seconds()+0.5) * minNonZero(*limitConn, *limit/len(w))
eps := int(float64(want)*(*epsilon) + 0.5)
log.Printf("clients: %d", *count)
log.Printf("global limit: %d [kB/s], per connection: %d [kB/s]", *limit/kb, *limitConn/kb)
log.Printf("transfer duration: %s", *duration)
log.Printf("expected bandwidth within range (%d, %d) [kB] (epsilon=%.2f)", want-eps, want+eps, *epsilon)
wg.Wait()
_ = l.Close()
for i := range w {
got := int(n[i])
if got < want-eps || got > want+eps {
t.Errorf("got %d not within range (%d, %d)", got, want-eps, want+eps)
}
}
}
// serve accepts connections on l until Accept fails. Each accepted connection
// gets a goroutine that streams zero bytes to the peer until a write fails,
// and then closes the connection.
func serve(l net.Listener) {
	var null byteReader
	for {
		c, err := l.Accept()
		if err != nil {
			// Accept fails as soon as the test closes the listener. log.Fatalf
			// here would call os.Exit and abort the test binary regardless
			// of the test's outcome, so just stop serving.
			log.Printf("Error accepting conn: %+v", err)
			return
		}
		go func(c net.Conn) {
			// io.Copy returns once the client hangs up; release the socket.
			defer c.Close()
			_, _ = io.Copy(c, null)
		}(c)
	}
}
// minNonZero returns the smaller of i and j, treating 0 as "unset". If exactly
// one argument is zero, the other is returned; 0 is returned only when both
// are zero.
func minNonZero(i, j int) int {
	switch {
	case i == 0:
		return j
	case j == 0:
		return i
	case i < j:
		return i
	default:
		return j
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment