Skip to content

Instantly share code, notes, and snippets.

@matthewmueller
Created August 31, 2020 09:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matthewmueller/d2eb103a2cd98abfcf6ecdfceb993374 to your computer and use it in GitHub Desktop.
Save matthewmueller/d2eb103a2cd98abfcf6ecdfceb993374 to your computer and use it in GitHub Desktop.
ResponseWriter wrapper that allows middleware to set headers without worrying about whether the response header has already been written.
package responsewriter
import (
"bytes"
"net/http"
"github.com/felixge/httpsnoop"
)
// Wrap returns a ResponseWriter that buffers the response instead of sending it.
//
// Calls to WriteHeader on the wrapper only record the status code, and calls to
// Write only append to an in-memory buffer; neither reaches w until Flush is
// called. Header is not hooked, so it returns w's own header map, which means
// headers can still be added or changed after Write has been called.
func Wrap(w http.ResponseWriter) *ResponseWriter {
	st := &state{}

	// recordCode ignores the next WriteHeader in the chain and just remembers
	// the most recent status code.
	recordCode := func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
		return func(code int) { st.Code = code }
	}
	// bufferBody ignores the next Write in the chain and appends to the buffer.
	bufferBody := func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
		return func(p []byte) (int, error) { return st.Body.Write(p) }
	}

	snooped := httpsnoop.Wrap(w, httpsnoop.Hooks{
		WriteHeader: recordCode,
		Write:       bufferBody,
	})

	return &ResponseWriter{
		ResponseWriter: snooped,
		state:          st,
		original:       w,
	}
}
// Unwrap reports whether w is a *ResponseWriter created by Wrap and, if so,
// returns it. Only w itself is checked; writers that wrap a *ResponseWriter
// further are not unwrapped. On failure it returns (nil, false).
func Unwrap(w http.ResponseWriter) (rw *ResponseWriter, ok bool) {
	if wrapped, isWrapped := w.(*ResponseWriter); isWrapped {
		return wrapped, true
	}
	return nil, false
}
// state is the response captured by a wrapped writer before it is flushed.
// Its zero value is ready to use: no status code and an empty body.
type state struct {
	// Code is the status code most recently passed to WriteHeader; 0 if
	// WriteHeader has not been called.
	Code int

	// Body accumulates every byte passed to Write.
	Body bytes.Buffer
}
// ResponseWriter is an http.ResponseWriter that holds the status code and body
// in memory until Flush sends them to the original writer.
//
// NOTE: Wrap constructs this type with a positional composite literal, so the
// field order below must not change.
type ResponseWriter struct {
	// http.ResponseWriter is the httpsnoop wrapper built in Wrap; its Write and
	// WriteHeader are hooked to fill state, and Header passes through.
	http.ResponseWriter
	// state holds the captured response; its Code and Body fields are promoted.
	*state
	// original is the writer passed to Wrap; Flush writes the buffered
	// response to it.
	original http.ResponseWriter
}
// Flush sends the buffered status code and body to the original
// http.ResponseWriter and returns the result of the final Write.
//
// If WriteHeader was never called on the wrapper, the recorded code is 0,
// which net/http rejects (it panics on codes outside 100-999); in that case
// Flush sends 200 OK, the same status net/http would imply for a bare Write.
// Flush does not clear the buffer, so it is meant to be called once.
func (rw *ResponseWriter) Flush() (int, error) {
	code := rw.state.Code
	if code == 0 {
		code = http.StatusOK
	}
	// original is already an http.ResponseWriter; no type assertion needed.
	rw.original.WriteHeader(code)
	return rw.original.Write(rw.state.Body.Bytes())
}
package responsewriter_test
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/matthewmueller/responsewriter"
"github.com/tj/assert"
)
// TestWrap checks that a buffered response is delivered on Flush with the
// recorded status, the written body, and a header that was set after Write.
func TestWrap(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		buffered := responsewriter.Wrap(w)
		defer buffered.Flush()
		buffered.WriteHeader(http.StatusCreated)
		buffered.Write([]byte(`hi world`))
		// Adding a header after Write must still take effect, because the
		// wrapper sends nothing until the deferred Flush runs.
		buffered.Header().Add("Content-Type", "text/plain")
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	res, err := http.Get(server.URL)
	assert.NoError(t, err)
	defer res.Body.Close()

	assert.Equal(t, http.StatusCreated, res.StatusCode)

	body, err := ioutil.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Equal(t, "hi world", string(body))

	assert.Equal(t, "text/plain", res.Header.Get("Content-Type"))
}
// TestUnwrap checks that Unwrap recovers the *ResponseWriter from a value typed
// as http.ResponseWriter, that the buffered status and body are visible on it,
// and that Unwrap rejects writers that were not produced by Wrap.
//
// The wrapper is driven directly with an httptest.ResponseRecorder instead of
// through a server handler. That way every assertion runs on the test goroutine;
// testing.T's FailNow, which the assertion helpers may call on failure, must not
// be called from a handler goroutine.
func TestUnwrap(t *testing.T) {
	rec := httptest.NewRecorder()
	var w http.ResponseWriter = responsewriter.Wrap(rec)
	w.WriteHeader(201)
	w.Write([]byte(`hi world`))
	w.Header().Add("Content-Type", "text/plain")

	rw, ok := responsewriter.Unwrap(w)
	assert.True(t, ok)
	assert.Equal(t, 201, rw.Code)
	assert.Equal(t, "hi world", rw.Body.String())

	// Flush was never called, so no body bytes reached the underlying writer.
	assert.Equal(t, 0, rec.Body.Len())

	// A writer that was not produced by Wrap is not unwrapped.
	_, ok = responsewriter.Unwrap(rec)
	assert.False(t, ok)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment