Skip to content

Instantly share code, notes, and snippets.

@brackendawson
Last active May 25, 2020 21:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brackendawson/01c5a1ca06389f908c1f24d14885a0f5 to your computer and use it in GitHub Desktop.
Save brackendawson/01c5a1ca06389f908c1f24d14885a0f5 to your computer and use it in GitHub Desktop.
How to make a server too complicated
package main
import (
	"context"
	"crypto/subtle"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"time"
)
// Credentials that the auth middleware compares against the HTTP Basic
// username and password of each request.
// NOTE(review): these are hard-coded in source — presumably acceptable only
// for a demo; confirm before using this anywhere real.
const (
username = "admin"
password = "password123"
)
// handler is the innermost application handler. It optionally sleeps for the
// number of seconds given in the request's "Wait" header, makes a best-effort
// upstream call, and then answers with "hi\n".
//
// If the request context is cancelled while waiting, the context error is
// logged and a 500 status is written instead.
func handler(w http.ResponseWriter, r *http.Request) {
	// A missing or malformed Wait header parses as 0, i.e. no delay.
	secs, _ := strconv.Atoi(r.Header.Get("Wait"))

	// Use an explicit timer so it can be stopped when the request is
	// cancelled first, instead of lingering until it fires.
	delay := time.NewTimer(time.Second * time.Duration(secs))
	defer delay.Stop()

	select {
	case <-delay.C:
	case <-r.Context().Done():
		getLogger(r).Print(r.Context().Err())
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	getLogger(r).Print("Somebody loves me")
	callUpstream(r)

	// The write error is deliberately ignored: there is nothing useful to do
	// if the client has gone away.
	_, _ = io.WriteString(w, "hi\n")
}

// callUpstream issues a GET to the local upstream endpoint using the incoming
// request's context and logs any failure. Failures are not propagated: the
// upstream call is best-effort and never changes the response to the client.
func callUpstream(r *http.Request) {
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet,
		"http://localhost:1337/this_request_has_the_X-Request-ID", nil)
	if err != nil {
		getLogger(r).Print("Failed to make upstream request: ", err)
		// Without a request there is nothing to send; client.Do(nil) would panic.
		return
	}

	res, err := client.Do(req)
	if err != nil {
		getLogger(r).Print("Upstream request failed: ", err)
		return
	}
	// The body must always be closed or the underlying connection leaks.
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		getLogger(r).Print("Upstream request returned bad status: ",
			res.StatusCode)
	}
}
func main() {
shutdown := make(chan struct{})
s := http.Server{
Handler: setRqID(rqLogger(rqTime(rqInProgress(shutdown, auth(
mediaType("text/plain", http.HandlerFunc(handler))))))),
}
s.RegisterOnShutdown(func() { close(shutdown) })
idleConnsClosed := make(chan struct{})
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
<-sigint
log.Print("Shutting down")
if err := s.Shutdown(context.Background()); err != nil {
log.Printf("Shutdown failed: %s", err)
}
close(idleConnsClosed)
}()
log.Print("Starting server")
if err := s.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("Server failed: %s", err)
}
<-idleConnsClosed
}
// contextKey is the type of the keys this file uses to store values in a
// request's context.
type contextKey string
const (
// rqIDKey is the context key under which setRqID stores the request ID string.
rqIDKey contextKey = "rqID"
// loggerKey is the context key under which rqLogger stores the per-request
// *log.Logger.
loggerKey contextKey = "logger"
)
// setRqID is middleware that makes sure every request carries an ID. The ID is
// taken from the incoming X-Request-ID header or, when that is empty, generated
// from rand.Int. It is stored in the request context under rqIDKey and echoed
// back in the X-Request-ID response header before next is invoked.
func setRqID(next http.Handler) http.Handler {
	assign := func(w http.ResponseWriter, r *http.Request) {
		rqID := r.Header.Get("X-Request-ID")
		if rqID == "" {
			rqID = strconv.Itoa(rand.Int())
		}

		ctx := context.WithValue(r.Context(), rqIDKey, rqID)
		w.Header().Set("X-Request-ID", rqID)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(assign)
}
// client is the HTTP client used for upstream requests. Its Transport is a
// transport wrapping http.DefaultTransport, so outgoing requests pass through
// transport.RoundTrip before being sent.
var client = &http.Client{
Transport: &transport{http.DefaultTransport},
}
// transport is an http.RoundTripper that forwards the request ID found in the
// request's context to the upstream server via the X-Request-ID header, then
// delegates to the embedded RoundTripper.
type transport struct {
	http.RoundTripper
}

// RoundTrip sets X-Request-ID on the outgoing request when the request context
// carries a string under rqIDKey, and then hands the request to the wrapped
// RoundTripper. Requests without an ID in their context pass through untouched.
func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
	if id, ok := r.Context().Value(rqIDKey).(string); ok {
		// The http.RoundTripper contract says RoundTrip must not modify the
		// caller's request, so set the header on a clone instead.
		r = r.Clone(r.Context())
		if r.Header == nil {
			// Clone preserves a nil header map; Set on it would panic.
			r.Header = make(http.Header)
		}
		r.Header.Set("X-Request-ID", id)
	}
	return t.RoundTripper.RoundTrip(r)
}
// rqTime is middleware that measures how long next takes to serve a request
// and logs the elapsed time through the request's logger.
func rqTime(next http.Handler) http.Handler {
	timed := func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		next.ServeHTTP(w, r)

		// Fetch the logger before reading the clock, as the elapsed time is
		// taken at the moment of printing.
		logger := getLogger(r)
		logger.Print("Request took ", time.Since(start))
	}
	return http.HandlerFunc(timed)
}
// rqLogger is middleware that creates a per-request logger writing to stderr,
// prefixed with the request ID (when one is present in the context under
// rqIDKey). It logs the incoming request, stores the logger in the request
// context under loggerKey for downstream handlers, and logs the final status
// code once next has returned.
func rqLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var prefix string
		if id, found := r.Context().Value(rqIDKey).(string); found {
			prefix = id + " "
		} else {
			log.Print("No request ID to prefix logger")
		}

		rqLog := log.New(os.Stderr, prefix, log.Flags())
		rqLog.Printf("Received request %s %s from %s %s", r.Method, r.URL,
			r.RemoteAddr, r.Header.Get("User-Agent"))

		ctx := context.WithValue(r.Context(), loggerKey, rqLog)

		// Wrap the writer so the status code can be reported afterwards; 200 is
		// the status when the handler never calls WriteHeader.
		catcher := &statusCatcher{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(catcher, r.Clone(ctx))

		rqLog.Print("Request completed with status: ", catcher.statusCode)
	})
}
// statusCatcher wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader so it can be inspected after the handler returns.
type statusCatcher struct {
	http.ResponseWriter
	// statusCode is the most recent code given to WriteHeader; the creator
	// chooses its initial value.
	statusCode int
}

// WriteHeader records code and then forwards it to the wrapped ResponseWriter.
func (sc *statusCatcher) WriteHeader(code int) {
	sc.statusCode = code
	sc.ResponseWriter.WriteHeader(code)
}
// getLogger returns the *log.Logger stored in r's context under loggerKey.
// When none is present it reports that on the standard logger and returns a
// fresh, prefix-less logger writing to stderr, so callers never need a nil
// check for the missing-logger case.
func getLogger(r *http.Request) *log.Logger {
	if stored, ok := r.Context().Value(loggerKey).(*log.Logger); ok {
		return stored
	}

	log.Print("No logger found in request")
	return log.New(os.Stderr, "", log.Flags())
}
// rqInProgress is middleware that logs "Request still in progress" for every
// request that is being served at the moment the shutdown channel is closed.
// For each request it starts a watcher goroutine that ends when either the
// shutdown channel is closed or the request's context is done.
func rqInProgress(shutdown chan struct{}, next http.Handler) http.Handler {
	watch := func(r *http.Request) {
		select {
		case <-r.Context().Done():
			// Request context ended before shutdown: nothing to report.
		case <-shutdown:
			getLogger(r).Print("Request still in progress")
		}
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		go watch(r)
		next.ServeHTTP(w, r)
	})
}
// auth is middleware that enforces HTTP Basic authentication against the
// package-level username and password. Requests that fail are answered with
// 401 Unauthorized and a WWW-Authenticate challenge; next is only invoked for
// requests with matching credentials.
func auth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// deny sends the 401 together with the Basic challenge. Every 401
		// response must carry WWW-Authenticate, and the Basic scheme requires
		// a realm parameter.
		deny := func() {
			w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`)
			w.WriteHeader(http.StatusUnauthorized)
		}

		u, p, ok := r.BasicAuth()
		if !ok {
			getLogger(r).Print("No authorization")
			deny()
			return
		}

		// Compare in constant time, and evaluate both comparisons before
		// branching, so response timing does not reveal how much of a
		// credential matched. (ConstantTimeCompare still returns early on a
		// length mismatch, so lengths are not hidden.)
		userOK := subtle.ConstantTimeCompare([]byte(u), []byte(username)) == 1
		passOK := subtle.ConstantTimeCompare([]byte(p), []byte(password)) == 1

		if !userOK {
			// The supplied name is attacker-controlled; %q keeps control
			// characters such as newlines from forging log lines.
			getLogger(r).Printf("Unknown user: %q", u)
			deny()
			return
		}
		if !passOK {
			getLogger(r).Print("Wrong password for ", username)
			deny()
			return
		}
		next.ServeHTTP(w, r)
	})
}
// mediaType is middleware for content negotiation on a single media type. It
// invokes next, with Content-Type preset to mediaType, when any media range in
// the request's Accept header(s) matches mediaType; otherwise it answers
// 406 Not Acceptable with a plain-text explanation.
//
// Accept values may be comma-separated lists and may carry parameters such as
// ";q=0.8"; parameters are ignored, so q-values are not used for ranking. A
// request with no Accept header at all is rejected with 406.
//
// mediaType panics at construction time if mediaType is not of the form
// "type/subtype".
func mediaType(mediaType string, next http.Handler) http.Handler {
	if !acceptable(mediaType, mediaType) {
		// Trailing space matters: fmt.Sprint adds no separator between strings.
		panic(fmt.Sprint("Invalid media type ", mediaType))
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		for _, line := range r.Header["Accept"] {
			// One header line can list several media ranges.
			for _, mediaRange := range strings.Split(line, ",") {
				// Drop parameters (";q=...", ";charset=...") before matching.
				if i := strings.IndexByte(mediaRange, ';'); i >= 0 {
					mediaRange = mediaRange[:i]
				}
				if acceptable(strings.TrimSpace(mediaRange), mediaType) {
					w.Header().Set("Content-Type", mediaType)
					next.ServeHTTP(w, r)
					return
				}
			}
		}
		getLogger(r).Print("No acceptable types: ", r.Header["Accept"])
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusNotAcceptable)
		// Write error ignored: nothing useful can be done if the client left.
		_, _ = fmt.Fprintf(w, "Media type should be %s\n", mediaType)
	})
}
// acceptable reports whether the media range accept (for example "text/plain",
// "text/*" or "*/*") matches mediaType (for example "text/plain").
//
// Both arguments must have the form "type/subtype" once parameters and
// surrounding whitespace are removed; anything else yields false. A "*" in
// either position of accept matches any value, and type and subtype are
// compared case-insensitively. Parameters (everything from the first ';') are
// ignored, so q-values are not interpreted.
func acceptable(accept, mediaType string) bool {
	acceptParts := strings.Split(bareMediaType(accept), "/")
	mediaTypeParts := strings.Split(bareMediaType(mediaType), "/")
	// Checking both lengths keeps the indexing below in range.
	if len(acceptParts) != 2 || len(mediaTypeParts) != 2 {
		return false
	}
	for i, part := range acceptParts {
		if part == "*" || strings.EqualFold(part, mediaTypeParts[i]) {
			continue
		}
		return false
	}
	return true
}

// bareMediaType returns s with any ";parameter" suffix and surrounding
// whitespace removed.
func bareMediaType(s string) string {
	if i := strings.IndexByte(s, ';'); i >= 0 {
		s = s[:i]
	}
	return strings.TrimSpace(s)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment