Skip to content

Instantly share code, notes, and snippets.

@skiz
Created April 16, 2016 07:00
Show Gist options
  • Save skiz/41190df62ea94dfcb3762edaaaa634a5 to your computer and use it in GitHub Desktop.
Save skiz/41190df62ea94dfcb3762edaaaa634a5 to your computer and use it in GitHub Desktop.
gin-gonic/contrib/sessions based CSRF with back button support, security enhancements, and code clarity.
// Package csrf provides session-based CSRF protection for the gin framework.
//
// This library is a heavily modified version of https://github.com/utrack/gin-csrf.
// The primary differences are that tokens are persisted for the life of the
// session (back button support), the secret is not stored in any form on the
// client side, and much of the code has been refactored for clarity.
package csrf
import (
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"io"

	"github.com/dchest/uniuri"
	"github.com/gin-gonic/contrib/sessions"
	"github.com/gin-gonic/gin"
)
// Names under which CSRF state is submitted, stored and looked up.
const (
	// csrfFormKey is the form field / URL query parameter that
	// defaultTokenGetter reads the submitted token from.
	csrfFormKey = "_csrf"

	// csrfSaltName and csrfTokenName are the session keys holding the
	// per-session salt and the token derived from it.
	csrfSaltName  = "csrfSalt"
	csrfTokenName = "csrfToken"

	// csrfSecretName is the gin context key under which Middleware publishes
	// the configured secret for GenerateToken.
	csrfSecretName = "csrfSecret"
)

// defaultSecret is the fallback used by Middleware when Options.Secret is
// empty. As its value says, it is meant to be overridden through Options.
const defaultSecret = "changeMeWithOptions"
// defaultIgnoreMethods lists the HTTP methods that Middleware lets through
// without token validation when Options.IgnoreMethods is nil.
var defaultIgnoreMethods = []string{
	"GET",
	"HEAD",
	"OPTIONS",
}
// defaultErrorFunc is the handler Middleware uses when Options.ErrorFunc is
// nil. It writes a plain-text "CSRF token mismatch" body with status 403.
var defaultErrorFunc = func(ctx *gin.Context) {
	const statusForbidden = 403 // HTTP 403 Forbidden
	ctx.String(statusForbidden, "CSRF token mismatch")
}
// defaultTokenGetter is the token extractor Middleware uses when
// Options.TokenGetter is nil. It returns the first non-empty value found, in
// this order:
//
//  1. the csrfFormKey form value,
//  2. the csrfFormKey URL query parameter,
//  3. the X-CSRF-TOKEN request header,
//  4. the X-XSRF-TOKEN request header.
//
// It returns the empty string when none of them carries a token.
var defaultTokenGetter = func(c *gin.Context) string {
	req := c.Request

	if v := req.FormValue(csrfFormKey); v != "" {
		return v
	}
	if v := req.URL.Query().Get(csrfFormKey); v != "" {
		return v
	}
	// Header fallbacks, checked in order of preference.
	for _, name := range []string{"X-CSRF-TOKEN", "X-XSRF-TOKEN"} {
		if v := req.Header.Get(name); v != "" {
			return v
		}
	}
	return ""
}
// Options configures the CSRF middleware. Any field left at its zero value is
// replaced with a package default by Middleware.
type Options struct {
	// Secret is hashed together with the per-session salt to produce the
	// token. Defaults to defaultSecret when empty.
	Secret string
	// IgnoreMethods lists the HTTP methods that skip token validation.
	// Defaults to defaultIgnoreMethods when nil.
	IgnoreMethods []string
	// ErrorFunc is called when validation fails, just before the request is
	// aborted. Defaults to defaultErrorFunc when nil.
	ErrorFunc gin.HandlerFunc
	// TokenGetter extracts the submitted token from the request. Defaults to
	// defaultTokenGetter when nil.
	TokenGetter func(c *gin.Context) string
}
// hash derives the CSRF token for a secret/salt pair: the SHA-1 digest of
// salt + "-" + secret, encoded with URL-safe base64 (padding included).
func hash(secret, salt string) string {
	digest := sha1.New()
	// Stream the three pieces into the digest; this is equivalent to hashing
	// the concatenated string. Writes to a hash.Hash never return an error,
	// so the results are intentionally not checked.
	io.WriteString(digest, salt)
	io.WriteString(digest, "-")
	io.WriteString(digest, secret)
	return base64.URLEncoding.EncodeToString(digest.Sum(nil))
}
// inArray reports whether value is an element of arr. The comparison is an
// exact, case-sensitive string match; a nil or empty arr yields false.
func inArray(arr []string, value string) bool {
	for i := range arr {
		if arr[i] != value {
			continue
		}
		return true
	}
	return false
}
// Middleware returns a gin handler that validates a session based CSRF token.
//
// Zero-valued fields of options fall back to package defaults: Secret to
// defaultSecret, IgnoreMethods to defaultIgnoreMethods, ErrorFunc to
// defaultErrorFunc and TokenGetter to defaultTokenGetter.
// NOTE(review): defaultSecret is a hard-coded placeholder; callers should
// always supply their own Secret.
//
// On every request the handler publishes the secret in the gin context under
// csrfSecretName, which is where GenerateToken reads it from. Requests whose
// method is listed in IgnoreMethods are then passed through unchecked. For
// every other request the salt stored in the session is hashed with the
// secret and compared with the token returned by TokenGetter. When the salt
// is missing or the tokens differ, ErrorFunc is called and the handler chain
// is aborted.
func Middleware(options Options) gin.HandlerFunc {
	if options.Secret == "" {
		options.Secret = defaultSecret
	}
	if options.IgnoreMethods == nil {
		options.IgnoreMethods = defaultIgnoreMethods
	}
	if options.ErrorFunc == nil {
		options.ErrorFunc = defaultErrorFunc
	}
	if options.TokenGetter == nil {
		options.TokenGetter = defaultTokenGetter
	}

	// reject reports the failure through ErrorFunc and stops the chain.
	reject := func(c *gin.Context) {
		options.ErrorFunc(c)
		c.Abort()
	}

	return func(c *gin.Context) {
		session := sessions.Default(c)
		// Set before the ignored-method shortcut so the secret is available
		// in the context on every request.
		c.Set(csrfSecretName, options.Secret)

		if inArray(options.IgnoreMethods, c.Request.Method) {
			c.Next()
			return
		}

		salt, ok := session.Get(csrfSaltName).(string)
		if !ok || len(salt) == 0 {
			reject(c)
			return
		}

		token := options.TokenGetter(c)
		expected := hash(options.Secret, salt)
		// The submitted token is attacker-controlled: compare in constant
		// time so response timing does not reveal how many leading bytes of
		// the expected token were guessed correctly.
		if subtle.ConstantTimeCompare([]byte(expected), []byte(token)) != 1 {
			reject(c)
			return
		}
		c.Next()
	}
}
// GetToken returns the CSRF token stored in the current session under
// csrfTokenName. When the session holds no non-empty string token, a new one
// is created through GenerateToken and returned instead.
func GetToken(c *gin.Context) string {
	// A failed type assertion leaves stored as "", which is treated the same
	// as a missing token.
	stored, _ := sessions.Default(c).Get(csrfTokenName).(string)
	if stored == "" {
		return GenerateToken(c)
	}
	return stored
}
// GenerateToken creates a new CSRF token for the current session and returns
// it. A fresh random salt is hashed with the secret that Middleware stored in
// the gin context; both the salt and the resulting token are written to the
// session, which is then saved.
//
// The secret is read with c.MustGet, so Middleware has to run earlier in the
// handler chain. The result of session.Save is not inspected.
func GenerateToken(c *gin.Context) string {
	store := sessions.Default(c)

	salt := uniuri.New()
	secret := c.MustGet(csrfSecretName).(string)
	token := hash(secret, salt)

	store.Set(csrfSaltName, salt)
	store.Set(csrfTokenName, token)
	store.Save()

	return token
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment