Skip to content

Instantly share code, notes, and snippets.

@mpoindexter
Created March 14, 2023 00:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mpoindexter/f6bc9dac16290343efba17129c49f9d5 to your computer and use it in GitHub Desktop.
Save mpoindexter/f6bc9dac16290343efba17129c49f9d5 to your computer and use it in GitHub Desktop.
net/http: setting a timeout on http requests that use TLS can result in excessive requests to server
package main
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"os"
"os/signal"
"sync"
"time"
)
// main starts a local HTTPS server whose TLS handshakes are slow and
// concurrency-limited (see server), then runs 100 concurrent client loops
// against it and prints per-second statistics until interrupted (Ctrl-C).
//
// Each printed line is: requests finished in the last second, errors in the
// last second, requests finished in total, and mean latency over the last
// second. A failed attempt is counted in both the request and error columns.
//
// The commented-out client Timeout is presumably the toggle for the
// timeout-related behavior being investigated.
func main() {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	defer cancel()

	cert, err := createCert()
	if err != nil {
		log.Fatal(err)
	}
	pool := x509.NewCertPool()
	pool.AddCert(cert.Leaf)

	addr, err := server(ctx, cert, newCountingHandler())
	if err != nil {
		log.Fatal(err)
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.TLSClientConfig = &tls.Config{
		RootCAs: pool,
	}
	transport.MaxIdleConns = 100
	transport.MaxIdleConnsPerHost = 100
	transport.ForceAttemptHTTP2 = false
	client := &http.Client{
		//Timeout: 20 * time.Second,
		Transport: transport,
	}

	timing := make(chan time.Duration, 100)
	errs := make(chan error, 100)
	target := "https://" + addr
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for ctx.Err() == nil {
				start := time.Now()
				if err := fetchOnce(ctx, client, target); err != nil {
					errs <- err
				}
				timing <- time.Since(start)
			}
		}()
	}
	go reportStats(timing, errs)

	<-ctx.Done()
	wg.Wait()
	// Every sender has returned, so closing is safe; this also ends reportStats.
	close(timing)
	close(errs)
}

// newCountingHandler returns the test handler. It increments the
// *requestCounter stored under requestCountType{} in the request context
// (installed per connection by server's ConnContext), sets
// "Connection: close" on the 100th request, and takes about 60ms per request.
func newCountingHandler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// NOTE(review): the counter is unsynchronized; this is presumably safe
		// only because HTTP/2 is disabled, so requests on one connection are
		// served one at a time.
		reqCount := r.Context().Value(requestCountType{}).(*requestCounter)
		reqCount.count++
		if reqCount.count == 100 {
			w.Header().Set("connection", "close")
		}
		time.Sleep(60 * time.Millisecond)
		w.WriteHeader(http.StatusOK)
	})
}

// fetchOnce sends one GET to target, bound to ctx, and drains the response
// body so the transport can reuse the connection. It returns the first error
// from building the request, sending it, or reading the body.
func fetchOnce(ctx context.Context, client *http.Client, target string) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return err
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	_, err = io.Copy(io.Discard, resp.Body)
	return err
}

// reportStats prints, once per second, the number of finished requests and
// errors in that second, the running total of finished requests, and the
// mean latency over that second. It returns when either channel is closed.
func reportStats(timing <-chan time.Duration, errs <-chan error) {
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	countLastSec, errsLastSec, countTotal := 0, 0, 0
	var timeLastSec time.Duration
	for {
		select {
		case <-ticker.C:
			var avg time.Duration
			if countLastSec != 0 {
				avg = timeLastSec / time.Duration(countLastSec)
			}
			fmt.Println(countLastSec, errsLastSec, countTotal, avg)
			countLastSec, errsLastSec, timeLastSec = 0, 0, 0
		case m, ok := <-timing:
			if !ok {
				return
			}
			countLastSec++
			countTotal++
			timeLastSec += m
		case _, ok := <-errs:
			if !ok {
				return
			}
			errsLastSec++
		}
	}
}
// server starts an HTTPS server on an ephemeral 127.0.0.1 port and returns
// its address as "localhost:<port>", which matches the certificate's DNS name
// as generated by createCert.
//
// TLS handshakes are deliberately slow: GetCertificate admits at most 50
// concurrent handshakes and sleeps 100ms in each. HTTP/2 is disabled, and
// each connection's context carries a fresh *requestCounter under
// requestCountType{}. The server is closed when ctx is done.
//
// It returns an error only if the listener cannot be created.
func server(ctx context.Context, cert tls.Certificate, handler http.Handler) (string, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", err
	}
	// connLimit is a counting semaphore pre-filled with 50 tokens: each
	// handshake takes a token and returns it when GetCertificate finishes.
	connLimit := make(chan struct{}, 50)
	for i := 0; i < 50; i++ {
		connLimit <- struct{}{}
	}
	var server http.Server
	server.TLSConfig = &tls.Config{
		GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
			<-connLimit
			defer func() {
				connLimit <- struct{}{}
			}()
			// Simulate an expensive handshake while holding a token.
			time.Sleep(100 * time.Millisecond)
			return &cert, nil
		},
	}
	server.Handler = handler
	// Silence the server's internal error logging (e.g. TLS handshake errors).
	server.ErrorLog = log.New(io.Discard, "", 0)
	// A non-nil, empty TLSNextProto disables automatic HTTP/2, so every
	// connection speaks HTTP/1.x.
	server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
	// Give each accepted connection its own request counter.
	server.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
		return context.WithValue(ctx, requestCountType{}, &requestCounter{})
	}
	go func() {
		<-ctx.Done()
		server.Close()
	}()
	go func() {
		// ServeTLS always returns a non-nil error. ErrServerClosed is the
		// expected result of the Close above. Anything else means the server
		// stopped unexpectedly, so report it rather than dropping it silently.
		if err := server.ServeTLS(listener, "", ""); err != http.ErrServerClosed {
			log.Printf("serve: %v", err)
		}
	}()
	port := listener.Addr().(*net.TCPAddr).Port
	return fmt.Sprintf("localhost:%v", port), nil
}
// createCert generates a self-signed ECDSA P-521 certificate for "localhost".
// The certificate is valid from now for 180 days and is usable for TLS server
// authentication. The returned tls.Certificate has Leaf populated, so callers
// can add it directly to an x509.CertPool as a trust anchor.
func createCert() (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("generating key: %w", err)
	}
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Test"},
		},
		NotBefore: now,
		NotAfter:  now.Add(180 * 24 * time.Hour),
		// KeyEncipherment applies only to RSA key transport. An ECDSA key
		// signs the TLS handshake, so DigitalSignature is the only usage needed.
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:              []string{"localhost"},
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("creating certificate: %w", err)
	}
	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("parsing certificate: %w", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  priv,
		Leaf:        leaf,
	}, nil
}
// Context key and per-connection state used by the request-counting handler.
type (
	// requestCountType is the context key under which a connection's
	// *requestCounter is stored. Using a distinct unexported type keeps the
	// key from colliding with context keys defined by other packages.
	requestCountType struct{}

	// requestCounter holds the number of requests seen on one connection.
	requestCounter struct {
		count int
	}
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment