Skip to content

Instantly share code, notes, and snippets.

@mariash
Last active January 16, 2024 21:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mariash/cf75a2deff1d16af14ef8020393ccd48 to your computer and use it in GitHub Desktop.
Save mariash/cf75a2deff1d16af14ef8020393ccd48 to your computer and use it in GitHub Desktop.
Golang ReverseProxy race condition
package main
import (
"bytes"
"io"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/url"
"time"
)
// main is a self-contained reproducer for the race described by this gist's
// title ("Golang ReverseProxy race condition").
//
// Setup, as wired below:
//   - A backend on :8081 writes an interim "100 Continue", then hijacks and
//     closes the connection without ever sending a final response.
//   - A frontend on :8080 serves an httputil.ReverseProxy whose per-request
//     client trace blocks inside Got100Continue until readyCh is signalled.
//   - The proxy's ErrorHandler signals readyCh (releasing one blocked hook) and
//     then writes rw.Header() ten million times. The wrapping handler also
//     writes rw.Header() after ServeHTTP returns.
//
// NOTE(review): the apparent intent is that a released trace hook (and the
// proxy's own 1xx handling that follows it) touches the client ResponseWriter
// concurrently with those header writes. Presumably this is meant to be run
// with `go run -race` to surface the report — confirm with the author.
func main() {
// readyCh is unbuffered: each send releases exactly one goroutine blocked in
// the Got100Continue hook below, and each send itself blocks until some hook
// receives it.
readyCh := make(chan struct{})
backendServer := http.NewServeMux()
// Backend: send an interim 100 response, then drop the connection so no final
// response is ever produced.
// NOTE(review): both the Hijacker assertion and Hijack's error are ignored; if
// rw were not an http.Hijacker, hj would be nil and the call would panic.
backendServer.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusContinue)
hj, _ := rw.(http.Hijacker)
conn, _, _ := hj.Hijack()
conn.Close()
})
// ListenAndServe errors are discarded. The fixed sleep is a crude wait for the
// listener to come up before it is used.
go http.ListenAndServe(":8081", backendServer)
time.Sleep(1 * time.Second)
// handler wraps the proxy. Every incoming request gets a client trace whose
// Got100Continue hook parks until readyCh is signalled. After the proxy
// returns, it writes a response header — another access to rw.
handler := func(p *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) {
return func(rw http.ResponseWriter, r *http.Request) {
r = r.WithContext(httptrace.WithClientTrace(r.Context(), &httptrace.ClientTrace{
Got100Continue: func() {
// Hold this hook open until someone sends on readyCh, stretching the
// window in which it can overlap with the proxy's error handling.
<-readyCh
},
}))
p.ServeHTTP(rw, r)
rw.Header().Set("X-Something", "Hello")
}
}
// Pre-load a single release: whichever Got100Continue call receives first
// proceeds immediately. Any later call blocks until ErrorHandler sends.
go func() {
readyCh <- struct{}{}
}()
target, err := url.Parse("http://localhost:8081")
if err != nil {
panic(err)
}
proxy := httputil.NewSingleHostReverseProxy(target)
// ErrorHandler releases one parked Got100Continue hook, then hammers
// rw.Header() in a tight loop so its writes overlap with whatever that
// released hook does next.
// NOTE(review): the send blocks until a hook receives it; if no hook is parked
// at this point, this handler hangs.
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
readyCh <- struct{}{}
for i := 0; i < 10000000; i++ {
rw.Header().Set("X-Something", "Hello")
}
}
// Route outbound attempts through the retrying transport, so several attempts
// against the failing backend happen before ErrorHandler is invoked.
proxy.Transport = &RetryTransport{T: http.DefaultTransport}
http.HandleFunc("/", handler(proxy))
go http.ListenAndServe(":8080", nil)
time.Sleep(1 * time.Second)
// Client: POST a small body with "Expect: 100-continue". Presumably the header
// is forwarded so the proxy's outbound request also waits for a 100 response,
// which is what triggers the Got100Continue hook — verify against the
// ReverseProxy header handling.
data := bytes.NewBufferString("Hello!")
req, err := http.NewRequest("POST", "http://localhost:8080", data)
if err != nil {
panic(err)
}
req.Header.Set("Expect", "100-continue")
resp, err := http.DefaultClient.Do(req)
if err != nil {
panic(err)
}
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
// Keep the process alive for a while after the client finishes. This
// presumably gives a late-running hook or handler time to execute (and the
// race detector time to report) before exit.
time.Sleep(10 * time.Second)
println("done")
}
// RetryTransport is an http.RoundTripper that delegates each attempt to T and
// retries failed attempts, making at most three attempts per request.
type RetryTransport struct {
	// T performs each attempt. A nil T means http.DefaultTransport.
	T http.RoundTripper
}

// RoundTrip sends req through the wrapped transport. If an attempt returns an
// error, it retries, for up to three attempts in total. It returns the first
// successful response, or the result of the last attempt made.
//
// Retrying stops early when req's context is already done, or when a fresh
// body cannot be obtained from req.GetBody.
//
// Request bodies: when req has a body and req.GetBody is set (as it is for
// requests built by http.NewRequest from a *bytes.Buffer, *bytes.Reader or
// *strings.Reader), each retry is sent as a clone of req carrying a fresh body,
// so the caller's request is never modified. Without GetBody (typically the
// case for server-side requests), the same req is resent as-is, even though an
// earlier attempt may already have consumed or closed its body.
func (tr *RetryTransport) RoundTrip(req *http.Request) (res *http.Response, err error) {
	const maxAttempts = 3

	rt := tr.T
	if rt == nil {
		rt = http.DefaultTransport
	}

	attempt := req
	for i := 0; i < maxAttempts; i++ {
		if i > 0 {
			// The caller has given up; another attempt cannot help.
			if req.Context().Err() != nil {
				break
			}
			retry, rerr := retryRequest(req)
			if rerr != nil {
				// The body cannot be replayed; report the transport error we
				// already have rather than sending a half-consumed body.
				break
			}
			attempt = retry
		}
		res, err = rt.RoundTrip(attempt)
		if err == nil {
			return res, nil
		}
	}
	return res, err
}

// retryRequest returns the request to send for a retry of req. If req has a
// body and a GetBody func, the result is a clone of req with a fresh body from
// GetBody. Otherwise req itself is returned unchanged.
func retryRequest(req *http.Request) (*http.Request, error) {
	if req.Body == nil || req.GetBody == nil {
		return req, nil
	}
	body, err := req.GetBody()
	if err != nil {
		return nil, err
	}
	clone := req.Clone(req.Context())
	clone.Body = body
	return clone, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment