Skip to content

Instantly share code, notes, and snippets.

@wchan2
Last active May 8, 2017 08:43
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wchan2/d3ba859cde87a873396d4c3e19be58f6 to your computer and use it in GitHub Desktop.
Save wchan2/d3ba859cde87a873396d4c3e19be58f6 to your computer and use it in GitHub Desktop.
GoLang HTTP Middleware

GoLang HTTP Middleware

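Utility functions that compose http.Handlers and http.HandlerFuncs into middleware chains. Every middleware and the final handler always run, in order, against the same http.ResponseWriter; a middleware cannot stop the chain early. Note that once any handler in the chain writes the status code to the http.ResponseWriter (or writes body bytes, which implicitly writes a 200 OK status), the status code and headers are committed: later WriteHeader calls from handlers further down the chain are ignored, although those handlers can still append to the response body.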

Use cases

Some use cases where this may be useful are listed below. Please feel free to comment to add to the list.

  • Requiring authentication before allowing access to the main handler
  • Proxy to different services depending on application logic or migrations
package http_middleware
import (
"net/http"
)
// Middleware is an http.Handler that serves each request by running two
// handlers in sequence against the same ResponseWriter and Request: first
// the middleware, then the wrapped handler.
//
// Both handlers always run. The middleware has no way to prevent the
// wrapped handler from executing, so any response written by the
// middleware is followed by whatever the handler writes.
type Middleware struct {
	middleware http.Handler // invoked first
	handler    http.Handler // invoked second
}

// ServeHTTP implements http.Handler. It calls the middleware and then the
// handler with the same arguments. A nil field panics when it is reached.
func (mw *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Step 1: the middleware sees the request first.
	mw.middleware.ServeHTTP(w, r)
	// Step 2: the wrapped handler always runs afterwards.
	mw.handler.ServeHTTP(w, r)
}
// NewMiddleware returns a *Middleware that runs middleware and then handler
// on every request. Both arguments are stored as given; neither is checked
// for nil.
func NewMiddleware(middleware, handler http.Handler) *Middleware {
	mw := new(Middleware)
	mw.middleware = middleware
	mw.handler = handler
	return mw
}
// ComposeMiddlewareHandler chains middlewares into a single decorator. The
// returned function wraps a handler h so that every middleware runs in the
// order given, followed by h, all sharing the same ResponseWriter and
// Request. Every middleware and h always run; none of them can stop the
// chain early.
//
// With no middlewares the returned function is the identity: it returns h
// unchanged.
func ComposeMiddlewareHandler(middlewares ...http.Handler) func(http.Handler) http.Handler {
	if len(middlewares) == 0 {
		// Returning nil here would make ComposeMiddlewareHandler()(h) panic
		// with a nil function call; "no middleware" composes to the identity.
		return func(h http.Handler) http.Handler { return h }
	}

	// Left-fold so the middlewares execute in argument order:
	// ((m0 then m1) then m2) ...
	middleware := middlewares[0]
	for _, next := range middlewares[1:] {
		middleware = NewMiddleware(middleware, next)
	}

	return func(h http.Handler) http.Handler {
		return NewMiddleware(middleware, h)
	}
}
// ComposeMiddlewareHandlerFunc is the http.HandlerFunc counterpart of
// ComposeMiddlewareHandler. The returned function wraps a HandlerFunc h so
// that every middleware runs in the order given, followed by h, all sharing
// the same ResponseWriter and Request. Every middleware and h always run;
// none of them can stop the chain early.
//
// With no middlewares the returned function is the identity: it returns h
// unchanged.
func ComposeMiddlewareHandlerFunc(middlewares ...http.HandlerFunc) func(http.HandlerFunc) http.HandlerFunc {
	if len(middlewares) == 0 {
		// Returning nil here would make ComposeMiddlewareHandlerFunc()(h)
		// panic with a nil function call; "no middleware" composes to the
		// identity.
		return func(h http.HandlerFunc) http.HandlerFunc { return h }
	}

	// Left-fold so the middlewares execute in argument order.
	middleware := middlewares[0]
	for _, next := range middlewares[1:] {
		middleware = compose(middleware, next)
	}

	return func(h http.HandlerFunc) http.HandlerFunc {
		return compose(middleware, h)
	}
}
// compose returns a HandlerFunc that calls middlewareA and then middlewareB
// with the same ResponseWriter and Request. Both always run; a nil function
// panics when it is reached.
func compose(middlewareA, middlewareB http.HandlerFunc) http.HandlerFunc {
	steps := [...]http.HandlerFunc{middlewareA, middlewareB}
	return func(w http.ResponseWriter, r *http.Request) {
		for _, step := range steps {
			step(w, r)
		}
	}
}
package http_middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
. "github.com/wchan2/http_middleware"
)
// Shared fixtures for the composition tests. handler only sets a 200 status,
// while each middleware only appends one line of text to the response body,
// so the recorded body shows the order in which the middlewares ran.
var (
	handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}

	middlewareA http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello world\n"))
	}

	middlewareB http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("good bye world\n"))
	}
)
// TestComposeMiddlewareHandler checks that ComposeMiddlewareHandler runs the
// middlewares in argument order before the wrapped handler, and that the
// final status code is 200.
func TestComposeMiddlewareHandler(t *testing.T) {
	handlerWithMiddleware := ComposeMiddlewareHandler(middlewareA, middlewareB)(handler)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		// Include the cause so a failure here is diagnosable.
		t.Fatalf("could not create request: %v", err)
	}

	recorder := httptest.NewRecorder()
	handlerWithMiddleware.ServeHTTP(recorder, req)

	// %q makes newline/whitespace differences visible in the failure output.
	if got, want := recorder.Body.String(), "hello world\ngood bye world\n"; got != want {
		t.Errorf("middleware did not write response body correctly: got %q, want %q", got, want)
	}
	if got, want := recorder.Code, http.StatusOK; got != want {
		t.Errorf("handler did not write status code correctly: got %d, want %d", got, want)
	}
}
// TestMiddlewareHandlerFunc checks that ComposeMiddlewareHandlerFunc runs
// the middlewares in argument order before the wrapped HandlerFunc, and that
// the final status code is 200.
func TestMiddlewareHandlerFunc(t *testing.T) {
	handlerWithMiddleware := ComposeMiddlewareHandlerFunc(middlewareA, middlewareB)(handler)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		// Include the cause so a failure here is diagnosable.
		t.Fatalf("could not create request: %v", err)
	}

	recorder := httptest.NewRecorder()
	handlerWithMiddleware.ServeHTTP(recorder, req)

	// %q makes newline/whitespace differences visible in the failure output.
	if got, want := recorder.Body.String(), "hello world\ngood bye world\n"; got != want {
		t.Errorf("middleware did not write response body correctly: got %q, want %q", got, want)
	}
	if got, want := recorder.Code, http.StatusOK; got != want {
		t.Errorf("handler did not write status code correctly: got %d, want %d", got, want)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment