Created February 9, 2019 15:50
-
-
Save yanniszark/0968e0ace0efc2f0307abefd69c8edb4 to your computer and use it in GitHub Desktop.
netrl_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package netrl_test | |
import ( | |
"context" | |
"errors" | |
"flag" | |
"io" | |
"log" | |
"net" | |
"sync" | |
"testing" | |
"time" | |
"github.com/yanniszark/scylla-code-assignment/netrl" | |
) | |
// Size units, in bytes.
const (
	kb = 1 << 10  // 1024 bytes
	mb = kb << 10 // 1024 * 1024 bytes
)

// Command-line knobs for TestListener. Pass them to the test binary, e.g.
// `go test -args -type=yannis -clients=10`.
var (
	// Number of concurrent client connections.
	count = flag.Int("clients", 5, "Number of clients")
	// Global (whole-listener) bandwidth cap.
	limit = flag.Int("limit", 5*256*kb, "Limit bandwidth (bytes per second)")
	// Per-connection bandwidth cap.
	limitConn = flag.Int("limit-conn", 256*kb, "Limit per connection bandwidth (bytes per second)")
	// How long each client keeps reading.
	duration = flag.Duration("time", 30*time.Second, "Test duration")
	// Relative tolerance applied to the expected byte count.
	epsilon = flag.Float64("epsilon", 0.05, "How close")
	// Listener implementation: "net" or "yannis".
	typ = flag.String("type", "net", "Listener type.")
)
// byteReader is an io.Reader that yields an endless stream of a single
// repeated byte value. Read never fails and always fills the whole buffer.
type byteReader byte

// Read sets every byte of p to br and reports len(p) bytes read with a nil
// error.
func (br byteReader) Read(p []byte) (int, error) {
	fill := byte(br)
	for idx := 0; idx < len(p); idx++ {
		p[idx] = fill
	}
	return len(p), nil
}
type timedWriter struct { | |
first time.Time | |
last time.Time | |
} | |
func (tw *timedWriter) Write(p []byte) (int, error) { | |
if tw.first.IsZero() { | |
tw.first = time.Now() | |
} | |
tw.last = time.Now() | |
if tw.duration() >= *duration { | |
return len(p), errors.New("stop") | |
} | |
return len(p), nil | |
} | |
func (tw *timedWriter) duration() time.Duration { | |
return tw.last.Sub(tw.first) | |
} | |
// Listener adapts an existing net.Listener by routing Accept through
// AcceptFunc. Every other method (Close, Addr) is provided by the embedded
// net.Listener.
type Listener struct {
	net.Listener
	AcceptFunc func() (net.Conn, error)
}

// Accept returns exactly what AcceptFunc returns.
func (l Listener) Accept() (net.Conn, error) {
	accept := l.AcceptFunc
	return accept()
}
func yannisListener(ctx context.Context, network, addr string, limit, connLimit int) (net.Listener, error) { | |
l, err := netrl.Listen(network, addr) | |
if err != nil { | |
return nil, err | |
} | |
l.SetGlobalLimit(float64(limit)) | |
return Listener{ | |
Listener: l, | |
AcceptFunc: func() (net.Conn, error) { | |
conn, err := l.Accept() | |
if err != nil { | |
return nil, err | |
} | |
if err := netrl.SetLimitForConn(conn, float64(connLimit)); err != nil { | |
return nil, err | |
} | |
return conn, nil | |
}, | |
}, nil | |
} | |
func netListener(ctx context.Context, network, addr string, limit, connLimit int) (net.Listener, error) { | |
l, err := net.Listen(network, addr) | |
if err != nil { | |
return nil, err | |
} | |
return Listener{ | |
Listener: l, | |
AcceptFunc: l.Accept, | |
}, nil | |
} | |
func TestListener(t *testing.T) { | |
flag.Parse() | |
var listen func(context.Context, string, string, int, int) (net.Listener, error) | |
switch *typ { | |
case "yannis": | |
listen = yannisListener | |
case "net": | |
listen = netListener | |
default: | |
t.Fatalf("unsupported listener type: %s", *typ) | |
} | |
l, err := listen(context.Background(), "tcp", "127.0.0.1:0", *limit, *limitConn) | |
if err != nil { | |
t.Fatalf("listen()=%s", err) | |
} | |
defer l.Close() | |
go serve(l) | |
time.Sleep(2 * time.Second) | |
d := &net.Dialer{ | |
Timeout: 2 * time.Second, | |
KeepAlive: 5 * time.Second, | |
} | |
w := make([]timedWriter, *count) | |
n := make([]int64, *count) | |
var wg sync.WaitGroup | |
for i := range w { | |
wg.Add(1) | |
go func(i int) { | |
defer wg.Done() | |
c, err := d.Dial(l.Addr().Network(), l.Addr().String()) | |
if err != nil { | |
t.Fatalf("Dial()=%s", err) | |
} | |
defer c.Close() | |
n[i], _ = io.Copy(&w[i], c) | |
}(i) | |
} | |
want := int(duration.Seconds()+0.5) * minNonZero(*limitConn, *limit/len(w)) | |
eps := int(float64(want)*(*epsilon) + 0.5) | |
log.Printf("clients: %d", *count) | |
log.Printf("global limit: %d [kB/s], per connection: %d [kB/s]", *limit/kb, *limitConn/kb) | |
log.Printf("transfer duration: %s", *duration) | |
log.Printf("expected bandwidth within range (%d, %d) [kB] (epsilon=%.2f)", want-eps, want+eps, *epsilon) | |
wg.Wait() | |
_ = l.Close() | |
for i := range w { | |
got := int(n[i]) | |
if got < want-eps || got > want+eps { | |
t.Errorf("got %d not within range (%d, %d)", got, want-eps, want+eps) | |
} | |
} | |
} | |
func serve(l net.Listener) { | |
var null byteReader | |
for { | |
c, err := l.Accept() | |
if err != nil { | |
log.Fatalf("Error accepting conn: %+v", err) | |
} | |
go io.Copy(c, null) | |
} | |
} | |
// minNonZero returns the smaller of i and j, treating 0 as "no limit". If
// exactly one argument is 0 the other one is returned; if both are 0 the
// result is 0. The original version returned 0 for minNonZero(x, 0).
func minNonZero(i, j int) int {
	switch {
	case i == 0:
		return j
	case j == 0:
		return i
	case i < j:
		return i
	default:
		return j
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.