Skip to content

Instantly share code, notes, and snippets.

@annismckenzie
Created February 26, 2017 12:56
Show Gist options
  • Save annismckenzie/8d1586e31cfcd4edc5e00ad5b6397519 to your computer and use it in GitHub Desktop.
Save annismckenzie/8d1586e31cfcd4edc5e00ad5b6397519 to your computer and use it in GitHub Desktop.
Example of a gin middleware that handles sessions, together with a buffering replacement for gin's response writer
package responseWriter
import (
"bufio"
"bytes"
"net"
"net/http"
"net/http/httptest"
)
// BufferingResponseWriter captures a response in memory instead of sending it
// to the client. Clear installs an *httptest.ResponseRecorder as the embedded
// http.ResponseWriter; the methods below read the status, body and size back out
// of that recorder so the caller can post-process the response before copying it
// to the real writer.
//
// The zero value is not ready for use: call Clear before handing it to a
// handler. Bytes, Status, WriteString and Flush panic if the embedded writer is
// not an *httptest.ResponseRecorder.
type BufferingResponseWriter struct {
	http.ResponseWriter
	// size counts the body bytes written through Write and WriteString since
	// the last Clear.
	size int
}

// WriteHeaderNow is deliberately a no-op: headers stay buffered until the
// caller copies them to the real writer.
func (brw *BufferingResponseWriter) WriteHeaderNow() {}

// recorder returns the embedded writer as the *httptest.ResponseRecorder
// installed by Clear.
func (brw *BufferingResponseWriter) recorder() *httptest.ResponseRecorder {
	return brw.ResponseWriter.(*httptest.ResponseRecorder)
}

// Bytes returns the buffered response body.
func (brw *BufferingResponseWriter) Bytes() []byte {
	return brw.recorder().Body.Bytes()
}

// Clear discards everything buffered so far (headers, status, body and the
// byte count) and installs a fresh recorder. The status code starts at 0 so
// that Status can report that no status has been recorded yet.
func (brw *BufferingResponseWriter) Clear() {
	brw.ResponseWriter = &httptest.ResponseRecorder{HeaderMap: make(http.Header), Body: new(bytes.Buffer), Code: 0}
	brw.size = 0
}

// Status returns the buffered status code, or 0 if nothing has set a status
// since the last Clear.
func (brw *BufferingResponseWriter) Status() int {
	return brw.recorder().Code
}

// Size returns the number of body bytes written since the last Clear.
func (brw *BufferingResponseWriter) Size() int {
	return brw.size
}

// Write buffers data as part of the response body.
func (brw *BufferingResponseWriter) Write(data []byte) (int, error) {
	n, err := brw.ResponseWriter.Write(data)
	brw.size += n
	return n, err
}

// WriteHeader records the status code in the buffer.
func (brw *BufferingResponseWriter) WriteHeader(code int) {
	brw.ResponseWriter.WriteHeader(code)
}

// WriteString buffers s as part of the response body.
func (brw *BufferingResponseWriter) WriteString(s string) (int, error) {
	n, err := brw.recorder().WriteString(s)
	brw.size += n
	return n, err
}

// Written always reports false: while the response is buffered nothing has
// been sent to the client.
func (brw *BufferingResponseWriter) Written() bool {
	return false
}

// Hijack implements http.Hijacker. A recorder has no connection to hand over,
// so unless the embedded writer supports hijacking this returns
// http.ErrNotSupported instead of panicking.
func (brw *BufferingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hj, ok := brw.ResponseWriter.(http.Hijacker); ok {
		return hj.Hijack()
	}
	return nil, nil, http.ErrNotSupported
}

// CloseNotify implements http.CloseNotifier. When the embedded writer does not
// support it (a recorder does not), the returned channel never receives.
func (brw *BufferingResponseWriter) CloseNotify() <-chan bool {
	if cn, ok := brw.ResponseWriter.(http.CloseNotifier); ok {
		return cn.CloseNotify()
	}
	return make(chan bool)
}

// Flush implements http.Flusher by flushing the recorder.
func (brw *BufferingResponseWriter) Flush() {
	brw.recorder().Flush()
}
package middleware
import (
"bytes"
"encoding/base64"
"net/http"
"time"
"our-git-server/company/project/config"
pb "our-git-server/company/project/config/session"
rw "our-git-server/company/project/responseWriter" // see implementation in the file below
"our-git-server/company/project/view"
"github.com/gin-gonic/gin"
"github.com/golang/protobuf/proto"
)
// Session loads and persists a session for the application to use and not care about the plumbing.
// This middleware needs to be pretty high up the middleware chain to be useful.
//
// For each request it:
//  1. decodes the session cookie (base64-encoded protobuf) into a pb.Session
//     and stores a pointer to it in the gin context under config.SessionName;
//  2. replaces c.Writer with a BufferingResponseWriter so downstream handlers
//     write into memory instead of to the client;
//  3. after the chain returns, renders a 404 via view.AbortWithHTTPError if
//     nothing at all was written;
//  4. copies the buffered headers to the real writer, adds a refreshed session
//     cookie if the serialized session changed, then writes the buffered
//     status and body.
func Session() gin.HandlerFunc {
	return func(c *gin.Context) {
		var (
			sessionBytesBefore []byte
			session            pb.Session // struct that holds our session data, (de)serialized to and from a Protocol buffer
			bufferingWriter    = &rw.BufferingResponseWriter{}
			originalWriter     = c.Writer
		)
		bufferingWriter.Clear()

		// Load the session from the cookie. A missing or undecodable cookie
		// yields an empty session.
		if sessionCookie, err := c.Request.Cookie(config.SessionName); err == nil {
			if sessionBytesBefore, err = base64.StdEncoding.DecodeString(sessionCookie.Value); err == nil {
				if proto.Unmarshal(sessionBytesBefore, &session) != nil {
					// Don't hand handlers a partially decoded session. sessionBytesBefore
					// keeps the bad bytes, so the cookie gets rewritten below.
					session.Reset()
				}
			}
		}
		c.Set(config.SessionName, &session)

		c.Writer = bufferingWriter
		c.Next() // continue along the middleware chain and handling the request

		// Everything below runs after the middleware chain and the controller action are done.
		status := bufferingWriter.Status()
		if status == 0 && len(bufferingWriter.Header()) == 0 && len(bufferingWriter.Bytes()) == 0 {
			// Nothing was written and no handler ran. The 404 page rendered by
			// view.AbortWithHTTPError (not shown here) still goes into the buffer,
			// because c.Writer has not been restored yet.
			status = http.StatusNotFound
			view.AbortWithHTTPError(c, status, nil)
		}
		c.Writer = originalWriter
		if c.Writer == nil {
			return
		}

		// Copy the buffered headers before adding the session cookie. The copy
		// assigns whole header values, so a Set-Cookie added earlier would be
		// replaced whenever a handler also set a cookie.
		for k, v := range bufferingWriter.Header() {
			c.Writer.Header()[k] = v
		}

		// Persist the session to the cookie, but only if its serialized form changed.
		session.UserId = "" // NOTE(review): UserId is cleared before persisting — presumably so it never travels in the cookie; confirm
		if sessionBytesAfter, err := proto.Marshal(&session); err == nil && !bytes.Equal(sessionBytesBefore, sessionBytesAfter) {
			http.SetCookie(c.Writer, &http.Cookie{
				Name:     config.SessionName,
				Value:    base64.StdEncoding.EncodeToString(sessionBytesAfter),
				Path:     "/",
				HttpOnly: true,
				Expires:  time.Now().Add(60 * 24 * time.Hour), // 60 days
			})
		}

		// Now write out the response: status (which also sends the headers), then body.
		if status == 0 {
			status = http.StatusOK
		}
		c.Writer.WriteHeader(status)
		// A failed write cannot be reported to the client at this point.
		_, _ = c.Writer.Write(bufferingWriter.Bytes())
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment