// Originally published as GitHub Gist zgiber/4b64f82d2a79a1702709
// by @zgiber, created January 15, 2016 15:53.

package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"golang.org/x/net/context"
)
// Context-aware counterparts of the net/http handler signatures: the context
// is passed explicitly as the first argument.
type (
	// Handler serves a request using the context it is given.
	Handler func(c context.Context, w http.ResponseWriter, r *http.Request)

	// Middleware runs before a Handler and returns the context to hand on to
	// the next step. Returning a context whose Err() is non-nil signals that
	// processing should stop.
	Middleware func(c context.Context, w http.ResponseWriter, r *http.Request) context.Context
)
//
// EXAMPLE MIDDLEWARE
//
// Header returns a Middleware that sets the response header key to value and
// hands the incoming context on unchanged.
func Header(key, value string) Middleware {
	setHeader := func(c context.Context, w http.ResponseWriter, r *http.Request) context.Context {
		respHeader := w.Header()
		respHeader.Set(key, value)
		return c
	}
	return setHeader
}
// Token returns a Middleware that requires a non-empty Authorization header.
//
// If the header is present, its raw value is stored in the context under the
// string key "token". If it is missing, the context is marked as failed with
// a 400 Bad Request *HTTPError via contextWithError. That helper returns an
// already-canceled context, so a Chain stops at this point.
func Token() Middleware {
	return func(c context.Context, w http.ResponseWriter, r *http.Request) context.Context {
		if token := r.Header.Get("Authorization"); token != "" {
			return context.WithValue(c, "token", token)
		}
		return contextWithError(c, NewHTTPError(http.StatusBadRequest, "Missing token"))
	}
}
// HeaderLogger returns a Middleware that logs every request header as one
// "Name value1,value2" line and hands the context on unchanged.
//
// Headers are visited in Go map iteration order, so line order is not stable
// between requests.
func HeaderLogger() Middleware {
	return func(c context.Context, w http.ResponseWriter, r *http.Request) context.Context {
		// Collect the lines and join them once. Re-joining an ever-growing
		// accumulator on every iteration copies it each time, which is
		// quadratic in the total header size.
		lines := make([]string, 0, len(r.Header))
		for name, values := range r.Header {
			lines = append(lines, name+" "+strings.Join(values, ",")+"\n")
		}
		log.Println(strings.Join(lines, ""))
		return c
	}
}
// helloMiddeware returns an experimental Middleware that writes a plain-text
// line to the response body. Write errors are logged and otherwise ignored,
// and the context is handed on unchanged.
//
// The misspelled name is kept: renaming it would also require changing its
// caller in exampleHandler.
func helloMiddeware() Middleware {
	return func(c context.Context, w http.ResponseWriter, r *http.Request) context.Context {
		if _, err := fmt.Fprintln(w, "Sorry for not being JSON"); err != nil {
			log.Println(err)
		}
		return c
	}
}
func main() {
	root := context.Background()

	// jsonAPI requires an Authorization header, then marks the response as JSON.
	jsonAPI := Chain(
		Token(),
		// Token(), // uncomment to check that a failing middleware stops the chain
		Header("Content-Type", "application/json"),
	)

	// Chains nest: debugJSONAPI logs the request headers and then runs jsonAPI.
	debugJSONAPI := Chain(HeaderLogger(), jsonAPI)

	// Middleware can also be passed straight to the Handle constructor.
	http.HandleFunc("/main", Handle(root, exampleHandler, debugJSONAPI))

	// NOTE(review): http.ListenAndServe configures no read/write timeouts.
	// Outside a demo, prefer an http.Server with ReadTimeout/WriteTimeout set.
	log.Fatal(http.ListenAndServe(":8080", nil))
}
//
// HANDLER
//
// exampleHandler writes a fixed user record encoded as JSON.
//
// It first calls helloMiddeware directly, to show that a Middleware can also
// be invoked from inside a handler. That middleware writes a plain-text line
// before the JSON, so the response body as a whole is not valid JSON.
func exampleHandler(c context.Context, w http.ResponseWriter, r *http.Request) {
	response := map[string]interface{}{
		"userID": 1234,
		"email":  "john@example.com",
	}
	// The context returned by the middleware is not needed here.
	helloMiddeware()(c, w, r)

	// The status is already committed: any Write implicitly sends 200. An
	// encoding failure can therefore only be logged, not reported to the
	// client.
	if err := json.NewEncoder(w).Encode(response); err != nil {
		log.Printf("encoding example response: %v", err)
	}
}
// HTTPError is an error carrying an HTTP status code and a message.
// *HTTPError implements the error interface.
type HTTPError struct {
	code    int
	message string
}

// Compile-time check that *HTTPError satisfies error.
var _ error = (*HTTPError)(nil)

// Error formats the error as "<code> <message>", e.g. "400 Missing token".
func (e *HTTPError) Error() string {
	return fmt.Sprintf("%d %s", e.code, e.message)
}

// NewHTTPError returns a *HTTPError holding the given status code and message.
// It returns the concrete pointer type, not the error interface, so it can be
// passed directly to functions that take a *HTTPError, such as
// contextWithError.
func NewHTTPError(code int, message string) *HTTPError {
	return &HTTPError{code: code, message: message}
}
// contextWithError records err in c under the string key "errorMessage" and
// returns a derived context that is already canceled, so its Err() is
// non-nil. Chain and Handle check Err() to stop processing and reply with the
// error. A nil err returns c unchanged.
func contextWithError(c context.Context, err *HTTPError) context.Context {
	if err == nil {
		return c
	}
	withErr := context.WithValue(c, "errorMessage", err)
	canceled, cancel := context.WithCancel(withErr)
	// Cancel immediately: the canceled state itself is the error signal.
	cancel()
	return canceled
}
// newErrorHandler returns the Handler used when middleware has left the
// context in an error state.
//
// The returned Handler replies with the *HTTPError stored in the context it is
// invoked with, under the key "errorMessage". If no usable *HTTPError is
// stored (for example, the context was canceled for another reason), it logs
// what it found and replies 500.
//
// The parameter c is not used; only the context passed to the returned
// Handler is consulted.
func newErrorHandler(c context.Context) Handler {
	return func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		stored := ctx.Value("errorMessage")
		httpErr, ok := stored.(*HTTPError)
		if !ok || httpErr == nil {
			// Log the value actually stored. The failed assertion leaves httpErr
			// nil, so logging it would only ever print "<nil>".
			log.Printf("unexpected middleware error value %#v (context error: %v)", stored, ctx.Err())
			http.Error(w, "wrong middleware error type", http.StatusInternalServerError)
			return
		}
		code := httpErr.code
		// net/http's WriteHeader panics on codes outside 100-999; degrade to 500.
		if code < 100 || code > 999 {
			code = http.StatusInternalServerError
		}
		http.Error(w, httpErr.Error(), code)
	}
}
// Chain composes mw into a single Middleware that runs each element in order.
//
// Each middleware receives the context returned by the one before it. When a
// middleware returns a context whose Err() is non-nil, the rest are skipped
// and that context is returned. Otherwise the context from the last
// middleware is returned, or the input context if mw is empty. Every
// middleware must return a non-nil context, because Err() is called on the
// result straight away.
func Chain(mw ...Middleware) Middleware {
	return func(c context.Context, w http.ResponseWriter, r *http.Request) context.Context {
		for i := 0; i < len(mw); i++ {
			if c = mw[i](c, w, r); c.Err() != nil {
				break
			}
		}
		return c
	}
}
// Handle adapts h and its middleware into a function with the
// http.HandlerFunc signature, suitable for http.HandleFunc.
//
// For every request it:
//   - derives a cancelable context from c;
//   - runs the middleware chain with that context;
//   - if a middleware left the context in an error state (Err() != nil),
//     replies through newErrorHandler;
//   - otherwise calls h with the resulting context.
//
// The derived context is canceled when the returned function exits.
//
// NOTE(review): the context derives from c, not from the incoming request, so
// client disconnects are not propagated to h.
func Handle(c context.Context, h Handler, mw ...Middleware) func(w http.ResponseWriter, r *http.Request) {
	// The chain does not depend on the request, so build it once rather than
	// on every call.
	chain := Chain(mw...)
	return func(w http.ResponseWriter, r *http.Request) {
		ctx, cancel := context.WithCancel(c)
		defer cancel()

		ctx = chain(ctx, w, r)
		// Same stop condition that Chain uses between middleware.
		if ctx.Err() != nil {
			newErrorHandler(ctx)(ctx, w, r)
			return
		}
		h(ctx, w, r)
	}
}
// (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")