Last active
May 25, 2020 21:14
-
-
Save brackendawson/01c5a1ca06389f908c1f24d14885a0f5 to your computer and use it in GitHub Desktop.
How to make a server too complicated
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"context"
	"crypto/subtle"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"time"
)
// username is the account name the auth middleware compares the Basic auth
// user against.
const username = "admin"

// password is the secret the auth middleware compares the Basic auth password
// against.
const password = "password123"
func handler(w http.ResponseWriter, r *http.Request) { | |
secs, _ := strconv.Atoi(r.Header.Get("Wait")) | |
select { | |
case <-time.After(time.Second * time.Duration(secs)): | |
case <-r.Context().Done(): | |
getLogger(r).Print(r.Context().Err()) | |
w.WriteHeader(http.StatusInternalServerError) | |
return | |
} | |
getLogger(r).Print("Somebody loves me") | |
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, | |
"http://localhost:1337/this_request_has_the_X-Request-ID", nil) | |
if err != nil { | |
getLogger(r).Print("Failed to make upstream request: ", err) | |
} | |
res, err := client.Do(req) | |
if err != nil { | |
getLogger(r).Print("Upstream request failed: ", err) | |
} else if res.StatusCode != http.StatusOK { | |
getLogger(r).Print("Upstream request returned bad status: ", | |
res.StatusCode) | |
} | |
_, _ = io.WriteString(w, "hi\n") | |
} | |
func main() { | |
shutdown := make(chan struct{}) | |
s := http.Server{ | |
Handler: setRqID(rqLogger(rqTime(rqInProgress(shutdown, auth( | |
mediaType("text/plain", http.HandlerFunc(handler))))))), | |
} | |
s.RegisterOnShutdown(func() { close(shutdown) }) | |
idleConnsClosed := make(chan struct{}) | |
go func() { | |
sigint := make(chan os.Signal, 1) | |
signal.Notify(sigint, os.Interrupt) | |
<-sigint | |
log.Print("Shutting down") | |
if err := s.Shutdown(context.Background()); err != nil { | |
log.Printf("Shutdown failed: %s", err) | |
} | |
close(idleConnsClosed) | |
}() | |
log.Print("Starting server") | |
if err := s.ListenAndServe(); err != http.ErrServerClosed { | |
log.Fatalf("Server failed: %s", err) | |
} | |
<-idleConnsClosed | |
} | |
// contextKey is the type of the keys this file uses to store values in a
// request context.
type contextKey string

// rqIDKey is the context key under which setRqID stores the request ID.
const rqIDKey contextKey = "rqID"

// loggerKey is the context key under which rqLogger stores the per-request
// logger.
const loggerKey contextKey = "logger"
func setRqID(next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
id := r.Header.Get("X-Request-ID") | |
if id == "" { | |
id = strconv.Itoa(rand.Int()) | |
} | |
r = r.WithContext(context.WithValue(r.Context(), rqIDKey, id)) | |
w.Header().Set("X-Request-ID", id) | |
next.ServeHTTP(w, r) | |
}) | |
} | |
var client = &http.Client{ | |
Transport: &transport{http.DefaultTransport}, | |
} | |
type transport struct { | |
http.RoundTripper | |
} | |
func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { | |
if id, ok := r.Context().Value(rqIDKey).(string); ok { | |
r.Header.Set("X-Request-ID", id) | |
} | |
return t.RoundTripper.RoundTrip(r) | |
} | |
func rqTime(next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
t := time.Now() | |
next.ServeHTTP(w, r) | |
getLogger(r).Print("Request took ", time.Since(t)) | |
}) | |
} | |
func rqLogger(next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
prefix, ok := r.Context().Value(rqIDKey).(string) | |
if !ok { | |
log.Print("No request ID to prefix logger") | |
} else { | |
prefix += " " | |
} | |
logger := log.New(os.Stderr, prefix, log.Flags()) | |
logger.Printf("Received request %s %s from %s %s", r.Method, r.URL, | |
r.RemoteAddr, r.Header.Get("User-Agent")) | |
r = r.Clone(context.WithValue(r.Context(), loggerKey, logger)) | |
ww := &statusCatcher{w, http.StatusOK} | |
next.ServeHTTP(ww, r) | |
logger.Print("Request completed with status: ", ww.statusCode) | |
}) | |
} | |
// statusCatcher is an http.ResponseWriter wrapper that remembers the status
// code passed to WriteHeader.
type statusCatcher struct {
	http.ResponseWriter
	// statusCode is the status most recently passed to WriteHeader.
	statusCode int
}

// WriteHeader records code and then passes it on to the wrapped
// ResponseWriter.
func (c *statusCatcher) WriteHeader(code int) {
	c.statusCode = code
	c.ResponseWriter.WriteHeader(code)
}
func getLogger(r *http.Request) *log.Logger { | |
logger, ok := r.Context().Value(loggerKey).(*log.Logger) | |
if !ok { | |
log.Print("No logger found in request") | |
return log.New(os.Stderr, "", log.Flags()) | |
} | |
return logger | |
} | |
func rqInProgress(shutdown chan struct{}, next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
go func() { | |
select { | |
case <-shutdown: | |
getLogger(r).Print("Request still in progress") | |
case <-r.Context().Done(): | |
} | |
}() | |
next.ServeHTTP(w, r) | |
}) | |
} | |
func auth(next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
u, p, ok := r.BasicAuth() | |
if !ok { | |
getLogger(r).Print("No authorization") | |
w.Header().Set("WWW-Authenticate", "Basic") | |
w.WriteHeader(http.StatusUnauthorized) | |
return | |
} | |
if u != username { | |
getLogger(r).Print("Unknown user: ", u) | |
w.WriteHeader(http.StatusUnauthorized) | |
return | |
} | |
if p != password { | |
getLogger(r).Print("Wrong password for ", username) | |
w.WriteHeader(http.StatusUnauthorized) | |
return | |
} | |
next.ServeHTTP(w, r) | |
}) | |
} | |
func mediaType(mediaType string, next http.Handler) http.Handler { | |
if !acceptable(mediaType, mediaType) { | |
panic(fmt.Sprint("Invalid media type", mediaType)) | |
} | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
for _, accept := range r.Header["Accept"] { | |
if acceptable(accept, mediaType) { | |
w.Header().Set("Content-Type", mediaType) | |
next.ServeHTTP(w, r) | |
return | |
} | |
} | |
getLogger(r).Print("No acceptable types: ", r.Header["Accept"]) | |
w.Header().Set("Content-Type", "text/plain") | |
w.WriteHeader(http.StatusNotAcceptable) | |
_, _ = io.WriteString(w, fmt.Sprintf("Media type should be %s\n", | |
mediaType)) | |
}) | |
} | |
// acceptable reports whether the media range accept admits mediaType. Both
// must have the form "type/subtype", otherwise the result is false. A "*" in
// either position of accept matches anything; other parts are compared
// case-insensitively.
func acceptable(accept, mediaType string) bool {
	acceptParts := strings.Split(accept, "/")
	mediaTypeParts := strings.Split(mediaType, "/")
	// Check both lengths up front so the indexing below cannot go out of
	// range on a malformed mediaType.
	if len(acceptParts) != 2 || len(mediaTypeParts) != 2 {
		return false
	}
	for i, part := range acceptParts {
		if part != "*" && !strings.EqualFold(part, mediaTypeParts[i]) {
			return false
		}
	}
	return true
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment