@anzdaddy
Last active June 5, 2024 23:28
Injected middleware idea

Injected Middleware

This sample explores the idea that middlewares can be injected into an http.Handler instead of wrapping it.

In the current model, a middleware wraps the handler:

type mwf func(http.Handler) http.Handler

func oldRecoverer(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			if e := recover(); e != nil {
				http.Error(w, fmt.Sprint(e), http.StatusInternalServerError)
			}
		}()
		h.ServeHTTP(w, r)
	})
}
⋮
http.ListenAndServe(":8000", oldRecoverer(http.HandlerFunc(myHandler)))

Under the injection model, the handler's own body invokes the middleware and defers the cleanup function it returns:

func myHandler(w http.ResponseWriter, r *http.Request) {
	defer recoverer(&w, &r)()
	⋮
}
⋮
http.ListenAndServe(":8000", http.HandlerFunc(myHandler))

Discussion

The key advantage of this approach is that it gives handlers fine-grained control over middleware execution. Each handler can:

  1. Cherry-pick middlewares or add extra middlewares to a common set.
  2. Run code before middlewares.
  3. Conditionally run middlewares.
  4. Set a breakpoint at the start of a specific handler and step into its middlewares.

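The mwf.wrap method adapts an injected middleware into a conventional middleware function (func(http.Handler) http.Handler), so injected middlewares still fit into code that expects the wrapping model. Going the other way (converting an existing middleware into an mwf) isn't possible: a conventional middleware must be handed the next handler so it can decide when, or whether, to call it, whereas an mwf runs inside the handler and never receives the handler itself.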

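import (
	"fmt"
	"log"
	"net/http"
)

// mwf is an injected middleware. It receives pointers to the handler's
// ResponseWriter and Request, so it can replace either one, and returns
// a cleanup function for the handler to defer.
type mwf func(w *http.ResponseWriter, r **http.Request) func()

// There is no fromWrapper: a conventional middleware can't become an mwf.

// wrap adapts an mwf to the conventional func(http.Handler) http.Handler shape.
func (m mwf) wrap(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer m(&w, &r)()
		h.ServeHTTP(w, r)
	})
}

// recoverer turns a panic in the handler into a 500 response.
func recoverer(w *http.ResponseWriter, _ **http.Request) func() {
	return func() {
		if e := recover(); e != nil {
			http.Error(*w, fmt.Sprint(e), http.StatusInternalServerError)
		}
	}
}

// analyzer swaps in a byte-counting ResponseWriter and logs the total
// once the handler has finished.
func analyzer(w *http.ResponseWriter, r **http.Request) func() {
	cw := &countResponseBytesWriter{ResponseWriter: *w}
	*w = cw
	return func() {
		log.Printf("responseBytes: %d", cw.n)
	}
}

// countResponseBytesWriter counts bytes written to the embedded ResponseWriter.
type countResponseBytesWriter struct {
	http.ResponseWriter
	n int64
}

func (w *countResponseBytesWriter) Write(p []byte) (int, error) {
	n, err := w.ResponseWriter.Write(p)
	w.n += int64(n)
	return n, err
}

// mws combines middlewares: it runs each setup in order and, when the
// returned cleanup is deferred, runs their cleanups in reverse order.
// Caveat: those cleanups are called by mws's closure rather than directly
// by the deferred call, so a recover() inside them returns nil.
func mws(m ...mwf) mwf {
	return func(w *http.ResponseWriter, r **http.Request) func() {
		deferred := make([]func(), 0, len(m))
		for _, f := range m {
			deferred = append(deferred, f(w, r))
		}
		return func() {
			for _, d := range deferred {
				defer d() // deferred in a loop, so they run last-to-first
			}
		}
	}
}

func myHandler(w http.ResponseWriter, r *http.Request) {
	// recoverer is deferred on its own so that its recover() is called
	// directly by a deferred function; the rest can still be combined.
	defer recoverer(&w, &r)()
	defer mws(analyzer)(&w, &r)()
}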
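One property of mws is worth checking before relying on it for recovery. The Go spec says recover returns nil unless it is called directly by a deferred function, and a cleanup that mws's combined closure defers in turn does not qualify: only the function named in the original defer statement can stop the panic. The sketch below reproduces both shapes without any HTTP plumbing; recoverInto, combine, direct, nested, and escaped are hypothetical names, with combine mirroring the loop in mws.

```go
package main

import "fmt"

// recoverInto mimics an mwf cleanup that calls recover() and records
// whether it actually stopped a panic.
func recoverInto(recovered *bool) func() {
	return func() {
		if e := recover(); e != nil {
			*recovered = true
		}
	}
}

// combine mimics mws: one deferred closure that defers each cleanup.
func combine(fs ...func()) func() {
	return func() {
		for _, f := range fs {
			defer f()
		}
	}
}

// direct defers the recovering cleanup itself.
func direct() (recovered bool) {
	defer recoverInto(&recovered)()
	panic("boom")
}

// nested defers the same cleanup through combine, as mws does.
func nested() (recovered bool) {
	defer combine(recoverInto(&recovered))()
	panic("boom")
}

// escaped reports whether f let its panic escape to the caller.
func escaped(f func() bool) (esc bool) {
	defer func() {
		if recover() != nil {
			esc = true
		}
	}()
	f()
	return false
}

func main() {
	fmt.Println("direct escaped:", escaped(direct))
	fmt.Println("nested escaped:", escaped(nested))
}
```

Running it should show that only the nested variant lets the panic escape, so a recovering middleware is safest when the handler defers it directly rather than through mws.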