// Program exercising the problem titled:
// "net/http: setting a timeout on http requests that use TLS can result in
// excessive requests to server".
//
// Source: GitHub gist mpoindexter/f6bc9dac16290343efba17129c49f9d5.
package main

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"flag"
	"fmt"
	"io"
	"log"
	"math/big"
	"net"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"time"
)
func main() { | |
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) | |
defer cancel() | |
cert, err := createCert() | |
if err != nil { | |
log.Fatal(err) | |
} | |
pool := x509.NewCertPool() | |
pool.AddCert(cert.Leaf) | |
addr, err := server(ctx, cert, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
reqCount := r.Context().Value(requestCountType{}).(*requestCounter) | |
reqCount.count++ | |
if reqCount.count == 100 { | |
w.Header().Set("connection", "close") | |
} | |
time.Sleep(60 * time.Millisecond) | |
w.WriteHeader(http.StatusOK) | |
})) | |
if err != nil { | |
log.Fatal(err) | |
} | |
transport := http.DefaultTransport.(*http.Transport).Clone() | |
transport.TLSClientConfig = &tls.Config{ | |
RootCAs: pool, | |
} | |
transport.MaxIdleConns = 100 | |
transport.MaxIdleConnsPerHost = 100 | |
transport.ForceAttemptHTTP2 = false | |
client := &http.Client{ | |
//Timeout: 20 * time.Second, | |
Transport: transport, | |
} | |
timing := make(chan time.Duration, 100) | |
errs := make(chan error, 100) | |
var wg sync.WaitGroup | |
for i := 0; i < 100; i++ { | |
wg.Add(1) | |
go func() { | |
defer wg.Done() | |
for ctx.Err() == nil { | |
start := time.Now() | |
func() { | |
req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil) | |
if err != nil { | |
errs <- err | |
return | |
} | |
resp, err := client.Do(req) | |
if err != nil { | |
errs <- err | |
return | |
} | |
defer resp.Body.Close() | |
io.Copy(io.Discard, resp.Body) | |
}() | |
timing <- time.Now().Sub(start) | |
} | |
}() | |
} | |
go func() { | |
countLastSec := 0 | |
countTotal := 0 | |
errsLastSec := 0 | |
timeLastSec := time.Duration(0) | |
timer := time.NewTicker(1 * time.Second) | |
for { | |
select { | |
case <-timer.C: | |
if countLastSec != 0 { | |
fmt.Println(countLastSec, errsLastSec, countTotal, timeLastSec/time.Duration(countLastSec)) | |
} else { | |
fmt.Println(countLastSec, errsLastSec, countTotal, 0) | |
} | |
countLastSec = 0 | |
errsLastSec = 0 | |
timeLastSec = 0 | |
case m, ok := <-timing: | |
if ok { | |
countLastSec++ | |
countTotal++ | |
timeLastSec += m | |
} else { | |
return | |
} | |
case _, ok := <-errs: | |
if ok { | |
errsLastSec++ | |
} else { | |
return | |
} | |
} | |
} | |
}() | |
<-ctx.Done() | |
wg.Wait() | |
close(timing) | |
close(errs) | |
} | |
func server(ctx context.Context, cert tls.Certificate, handler http.Handler) (string, error) { | |
listener, err := net.Listen("tcp", "127.0.0.1:0") | |
if err != nil { | |
return "", err | |
} | |
connLimit := make(chan struct{}, 50) | |
for i := 0; i < 50; i++ { | |
connLimit <- struct{}{} | |
} | |
var server http.Server | |
server.TLSConfig = &tls.Config{ | |
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { | |
<-connLimit | |
defer func() { | |
connLimit <- struct{}{} | |
}() | |
time.Sleep(100 * time.Millisecond) | |
return &cert, nil | |
}, | |
} | |
server.Handler = handler | |
server.ErrorLog = log.New(io.Discard, "", 0) | |
server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) | |
server.ConnContext = func(ctx context.Context, c net.Conn) context.Context { | |
return context.WithValue(ctx, requestCountType{}, &requestCounter{}) | |
} | |
go func() { | |
<-ctx.Done() | |
server.Close() | |
}() | |
go func() { | |
server.ServeTLS(listener, "", "") | |
}() | |
port := listener.Addr().(*net.TCPAddr).Port | |
return fmt.Sprintf("localhost:%v", port), nil | |
} | |
// createCert generates a self-signed ECDSA P-521 certificate for the DNS name
// "localhost". The certificate is valid for 180 days from now and is usable
// for TLS server authentication. The returned tls.Certificate has Leaf set, so
// callers can add it directly to an x509.CertPool.
func createCert() (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, err
	}
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Test"},
		},
		NotBefore: now,
		NotAfter:  now.Add(180 * 24 * time.Hour),
		// KeyEncipherment applies only to RSA keys. An ECDSA key is used to
		// sign the handshake, so DigitalSignature is the only usage needed.
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:              []string{"localhost"},
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return tls.Certificate{}, err
	}
	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		return tls.Certificate{}, err
	}
	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  priv,
		Leaf:        leaf,
	}, nil
}
type (
	// requestCountType is the context key under which server's ConnContext
	// stores the *requestCounter for a connection.
	requestCountType struct{}

	// requestCounter holds the number of requests the handler has seen on a
	// single connection.
	requestCounter struct {
		count int
	}
)
// GitHub page footer: sign up or sign in to comment on this gist.