Skip to content

Instantly share code, notes, and snippets.

@deadcheat
Created October 5, 2016 18:56
Show Gist options
  • Save deadcheat/635fce6e413af653a9d41121b0fdfee8 to your computer and use it in GitHub Desktop.
Save deadcheat/635fce6e413af653a9d41121b0fdfee8 to your computer and use it in GitHub Desktop.
CORS middleware for the goadesign/goa framework, heavily inspired by labstack/echo's middleware/cors.go.
package cors
import (
"net/http"
"strconv"
"strings"
"github.com/goadesign/goa"
"golang.org/x/net/context"
)
// GoaCORSConfig configures the middleware returned by GoaCORSWithConfig.
type GoaCORSConfig struct {
	// Skipper is consulted first for every request; when it returns true the
	// request is passed to the next handler with no CORS processing. A nil
	// Skipper is replaced by DefaultGoaCORSConfig.Skipper.
	Skipper Skipper

	// AllowOrigins lists origins compared verbatim against the request's
	// Origin header; the entry "*" matches any origin. An empty list is
	// replaced by DefaultGoaCORSConfig.AllowOrigins.
	AllowOrigins []string

	// AllowMethods is joined with "," into Access-Control-Allow-Methods on
	// preflight responses. An empty list is replaced by
	// DefaultGoaCORSConfig.AllowMethods.
	AllowMethods []string

	// AllowHeaders is joined with "," into Access-Control-Allow-Headers on
	// preflight responses. When empty, the request's
	// Access-Control-Request-Headers value is echoed back instead.
	AllowHeaders []string

	// AllowCredentials adds "Access-Control-Allow-Credentials: true" to
	// responses for allowed origins.
	AllowCredentials bool

	// ExposeHeaders is joined with "," into Access-Control-Expose-Headers on
	// non-preflight responses for allowed origins.
	ExposeHeaders []string

	// MaxAge, when positive, is written as Access-Control-Max-Age on
	// preflight responses (the header is expressed in seconds).
	MaxAge int
}
// DefaultGoaCORSConfig is the configuration used by GoaCORS. GoaCORSWithConfig
// also uses it as the source of fallback values for a nil Skipper and for an
// empty AllowOrigins or AllowMethods. It never skips a request, allows every
// origin, and allows the GET, HEAD, PUT, PATCH, POST and DELETE methods.
var DefaultGoaCORSConfig = GoaCORSConfig{
	AllowOrigins: []string{"*"},
	AllowMethods: []string{GET, HEAD, PUT, PATCH, POST, DELETE},
	Skipper:      defaultSkipper,
}
// HTTP method names used by the CORS middleware and its default config.
const (
	GET     = "GET"     // GET HTTP method.
	HEAD    = "HEAD"    // HEAD HTTP method.
	POST    = "POST"    // POST HTTP method.
	PUT     = "PUT"     // PUT HTTP method.
	PATCH   = "PATCH"   // PATCH HTTP method.
	DELETE  = "DELETE"  // DELETE HTTP method.
	OPTIONS = "OPTIONS" // OPTIONS HTTP method; marks a request as a CORS preflight.
)
// Header names read from requests or written to responses by the CORS
// middleware.
const (
	// HeaderOrigin is the "Origin" request header.
	HeaderOrigin = "Origin"

	// HeaderAccessControlRequestMethod is the "Access-Control-Request-Method"
	// header sent on preflight requests.
	HeaderAccessControlRequestMethod = "Access-Control-Request-Method"

	// HeaderAccessControlRequestHeaders is the "Access-Control-Request-Headers"
	// header sent on preflight requests.
	HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"

	// HeaderVary is the "Vary" response header.
	HeaderVary = "Vary"

	// HeaderAccessControlAllowOrigin is the "Access-Control-Allow-Origin"
	// response header.
	HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"

	// HeaderAccessControlAllowMethods is the "Access-Control-Allow-Methods"
	// response header.
	HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"

	// HeaderAccessControlAllowHeaders is the "Access-Control-Allow-Headers"
	// response header.
	HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"

	// HeaderAccessControlAllowCredentials is the
	// "Access-Control-Allow-Credentials" response header.
	HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"

	// HeaderAccessControlExposeHeaders is the "Access-Control-Expose-Headers"
	// response header.
	HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"

	// HeaderAccessControlMaxAge is the "Access-Control-Max-Age" response
	// header.
	HeaderAccessControlMaxAge = "Access-Control-Max-Age"

	// HeaderContentType is the "Content-Type" header name.
	HeaderContentType = "Content-Type"
)
// GoaCORS returns CORS middleware configured with the current value of
// DefaultGoaCORSConfig. It is shorthand for
// GoaCORSWithConfig(service, DefaultGoaCORSConfig).
func GoaCORS(service *goa.Service) goa.Middleware {
	defaults := DefaultGoaCORSConfig
	return GoaCORSWithConfig(service, defaults)
}
// GoaCORSWithConfig returns goa middleware that applies CORS handling as
// described by conf. A nil Skipper and an empty AllowOrigins or AllowMethods
// are replaced by the corresponding DefaultGoaCORSConfig values.
//
// Non-OPTIONS requests are treated as simple requests. When the Origin header
// is present and allowed, Access-Control-Allow-Origin is set, plus
// Allow-Credentials and Expose-Headers when configured. The request is always
// passed on to the next handler.
//
// OPTIONS requests are treated as preflight requests. When the Origin header
// is present and allowed, the preflight headers are set and the middleware
// itself answers 204 No Content without calling the next handler. Otherwise
// the request is passed on.
//
// service is not used by the returned middleware; it is kept in the signature
// for backward compatibility.
func GoaCORSWithConfig(service *goa.Service, conf GoaCORSConfig) goa.Middleware {
	if conf.Skipper == nil {
		conf.Skipper = DefaultGoaCORSConfig.Skipper
	}
	if len(conf.AllowOrigins) == 0 {
		conf.AllowOrigins = DefaultGoaCORSConfig.AllowOrigins
	}
	if len(conf.AllowMethods) == 0 {
		conf.AllowMethods = DefaultGoaCORSConfig.AllowMethods
	}

	// These header values depend only on conf, so build them once instead of
	// on every request.
	allowMethods := strings.Join(conf.AllowMethods, ",")
	allowHeaders := strings.Join(conf.AllowHeaders, ",")
	exposeHeaders := strings.Join(conf.ExposeHeaders, ",")
	maxAge := strconv.Itoa(conf.MaxAge)

	return func(h goa.Handler) goa.Handler {
		return func(c context.Context, rw http.ResponseWriter, req *http.Request) error {
			if conf.Skipper(c, rw, req) {
				return h(c, rw, req)
			}

			// An Origin header that is present but empty still counts as
			// present, hence the separate map lookup.
			origin := req.Header.Get(HeaderOrigin)
			_, hasOrigin := req.Header[HeaderOrigin]

			// The first matching entry wins; "*" matches any origin and is
			// echoed back literally.
			// NOTE(review): browsers reject "Access-Control-Allow-Origin: *" on
			// credentialed requests, so "*" combined with AllowCredentials will
			// not work for cookie/auth requests. Consider reflecting origin in
			// that case (the current tests pin "*").
			allowedOrigin := ""
			for _, o := range conf.AllowOrigins {
				if o == "*" || o == origin {
					allowedOrigin = o
					break
				}
			}

			header := rw.Header()

			// Simple request
			if req.Method != OPTIONS {
				header.Add(HeaderVary, HeaderOrigin)
				if !hasOrigin || allowedOrigin == "" {
					return h(c, rw, req)
				}
				header.Set(HeaderAccessControlAllowOrigin, allowedOrigin)
				if conf.AllowCredentials {
					header.Set(HeaderAccessControlAllowCredentials, "true")
				}
				if exposeHeaders != "" {
					header.Set(HeaderAccessControlExposeHeaders, exposeHeaders)
				}
				return h(c, rw, req)
			}

			// Preflight request
			header.Add(HeaderVary, HeaderOrigin)
			header.Add(HeaderVary, HeaderAccessControlRequestMethod)
			header.Add(HeaderVary, HeaderAccessControlRequestHeaders)
			if !hasOrigin || allowedOrigin == "" {
				return h(c, rw, req)
			}
			header.Set(HeaderAccessControlAllowOrigin, allowedOrigin)
			header.Set(HeaderAccessControlAllowMethods, allowMethods)
			if conf.AllowCredentials {
				header.Set(HeaderAccessControlAllowCredentials, "true")
			}
			if allowHeaders != "" {
				header.Set(HeaderAccessControlAllowHeaders, allowHeaders)
			} else if requested := req.Header.Get(HeaderAccessControlRequestHeaders); requested != "" {
				// No explicit allow-list: echo back what the client asked for.
				header.Set(HeaderAccessControlAllowHeaders, requested)
			}
			if conf.MaxAge > 0 {
				header.Set(HeaderAccessControlMaxAge, maxAge)
			}

			// A 204 response must not carry a body. Sending one (as
			// service.Send would) makes a net/http ResponseWriter fail with
			// http.ErrBodyNotAllowed, so only the status line is written.
			rw.WriteHeader(http.StatusNoContent)
			return nil
		}
	}
}
package cors_test
import (
"net/http"
"testing"
"app/middleware/cors"
"golang.org/x/net/context"
"github.com/goadesign/goa"
. "github.com/smartystreets/goconvey/convey"
)
func TestCORS(t *testing.T) {
var ctx context.Context
var req *http.Request
var rw http.ResponseWriter
var service *goa.Service
Convey("no origin header", t, func() {
service = newService(nil)
req, _ = http.NewRequest(cors.GET, "/", nil)
rw = newTestResponseWriter()
ctx = newContext(service, rw, req, nil)
var newCtx context.Context
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
newCtx = ctx
return service.Send(ctx, http.StatusOK, "ok")
}
t := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{
AllowCredentials: true,
})(h)
err := t(ctx, rw, req)
So(err, ShouldBeNil)
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "")
})
Convey("Empty origin header", t, func() {
service = newService(nil)
req, _ = http.NewRequest(cors.GET, "/", nil)
req.Header.Set(cors.HeaderOrigin, "")
rw = newTestResponseWriter()
ctx = newContext(service, rw, req, nil)
var newCtx context.Context
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
newCtx = ctx
return service.Send(ctx, http.StatusOK, "ok")
}
t := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{
AllowCredentials: true,
})(h)
err := t(ctx, rw, req)
So(err, ShouldBeNil)
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "*")
})
Convey("Wildcard Origin", t, func() {
service = newService(nil)
req, _ = http.NewRequest(cors.GET, "/", nil)
req.Header.Set(cors.HeaderOrigin, "localhost")
rw = newTestResponseWriter()
ctx = newContext(service, rw, req, nil)
var newCtx context.Context
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
newCtx = ctx
return service.Send(ctx, http.StatusOK, "ok")
}
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{
AllowCredentials: true,
AllowOrigins: []string{"*"},
})(h)
err := testee(ctx, rw, req)
So(err, ShouldBeNil)
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "*")
})
Convey("Simple Request", t, func() {
service = newService(nil)
req, _ = http.NewRequest(cors.GET, "/", nil)
req.Header.Set(cors.HeaderOrigin, "localhost")
rw = newTestResponseWriter()
ctx = newContext(service, rw, req, nil)
var newCtx context.Context
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
newCtx = ctx
return service.Send(ctx, http.StatusOK, "ok")
}
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{
AllowCredentials: true,
AllowOrigins: []string{"localhost"},
})(h)
err := testee(ctx, rw, req)
So(err, ShouldBeNil)
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "localhost")
})
Convey("Preflight request", t, func() {
service = newService(nil)
req, _ = http.NewRequest(cors.OPTIONS, "/", nil)
req.Header.Set(cors.HeaderOrigin, "localhost")
req.Header.Set(cors.HeaderContentType, "application/json")
rw = newTestResponseWriter()
ctx = newContext(service, rw, req, nil)
var newCtx context.Context
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
newCtx = ctx
return service.Send(ctx, http.StatusOK, "ok")
}
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{
AllowCredentials: true,
AllowOrigins: []string{"localhost"},
MaxAge: 3600,
})(h)
err := testee(ctx, rw, req)
So(err, ShouldBeNil)
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "localhost")
So(rw.Header().Get(cors.HeaderAccessControlAllowMethods), ShouldNotBeEmpty)
So(rw.Header().Get(cors.HeaderAccessControlAllowCredentials), ShouldEqual, "true")
So(rw.Header().Get(cors.HeaderAccessControlMaxAge), ShouldEqual, "3600")
})
Convey("not allowed header", t, func() {
service = newService(nil)
req, _ = http.NewRequest(cors.GET, "/", nil)
req.Header.Set(cors.HeaderOrigin, "localhost")
rw = newTestResponseWriter()
ctx = newContext(service, rw, req, nil)
var newCtx context.Context
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
newCtx = ctx
return service.Send(ctx, http.StatusOK, "ok")
}
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{
AllowCredentials: true,
AllowOrigins: []string{"example.com"},
})(h)
err := testee(ctx, rw, req)
So(err, ShouldBeNil)
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "")
})
}
package cors
import (
"net/http"
"golang.org/x/net/context"
)
// Skipper decides, per request, whether the CORS middleware should be
// bypassed. Returning true hands the request straight to the next handler
// without adding any CORS headers.
type Skipper func(ctx context.Context, w http.ResponseWriter, r *http.Request) bool
// defaultSkipper is the Skipper used when none is configured. It ignores its
// arguments and never skips, so every request goes through CORS handling.
func defaultSkipper(_ context.Context, _ http.ResponseWriter, _ *http.Request) bool {
	skip := false
	return skip
}
package cors_test
import (
"net/http"
"net/url"
"github.com/goadesign/goa"
"golang.org/x/net/context"
)
// newService returns a goa service named "test" for use in tests. The JSON
// encoder and decoder are registered for "*/*" so that every content type can
// be encoded and decoded. logger is installed via WithLogger.
func newService(logger goa.LogAdapter) *goa.Service {
	svc := goa.New("test")
	svc.Encoder.Register(goa.NewJSONEncoder, "*/*")
	svc.Decoder.Register(goa.NewJSONDecoder, "*/*")
	svc.WithLogger(logger)
	return svc
}
// newContext builds a goa request context for rw and req. The base context is
// taken from a controller named "test" on service, and params are attached
// as the request parameters.
func newContext(service *goa.Service, rw http.ResponseWriter, req *http.Request, params url.Values) context.Context {
	base := service.NewController("test").Context
	return goa.NewContext(base, rw, req, params)
}
// logEntry is one message captured by testLogger together with its data.
type logEntry struct {
	Msg  string
	Data []interface{}
}

// testLogger records Info and Error calls in memory so tests can inspect
// them. Context holds values that are prepended to the data of every entry.
type testLogger struct {
	Context      []interface{}
	InfoEntries  []logEntry
	ErrorEntries []logEntry
}

// Info records an informational entry.
func (t *testLogger) Info(msg string, data ...interface{}) {
	t.InfoEntries = append(t.InfoEntries, t.newEntry(msg, data))
}

// Error records an error entry.
func (t *testLogger) Error(msg string, data ...interface{}) {
	t.ErrorEntries = append(t.ErrorEntries, t.newEntry(msg, data))
}

// newEntry builds an entry whose Data is Context followed by data, copied
// into a fresh slice. Appending directly to t.Context would make entries
// share Context's spare capacity, so a later entry could overwrite an earlier
// entry's data.
func (t *testLogger) newEntry(msg string, data []interface{}) logEntry {
	merged := make([]interface{}, 0, len(t.Context)+len(data))
	merged = append(merged, t.Context...)
	merged = append(merged, data...)
	return logEntry{Msg: msg, Data: merged}
}
// New appends data to the logger's Context and returns the receiver itself
// rather than a child logger. The added values therefore also apply to later
// entries logged through the original handle.
func (t *testLogger) New(data ...interface{}) goa.LogAdapter {
	extended := append(t.Context, data...)
	t.Context = extended
	return t
}
// testResponseWriter is an in-memory http.ResponseWriter that records the
// headers, body bytes and status code written to it.
type testResponseWriter struct {
	ParentHeader http.Header
	Body         []byte
	Status       int
}

// Compile-time check that testResponseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*testResponseWriter)(nil)

// newTestResponseWriter returns a recorder with an empty, non-nil header map.
func newTestResponseWriter() *testResponseWriter {
	return &testResponseWriter{ParentHeader: http.Header{}}
}

// Header returns the recorder's header map, which is the same map as
// ParentHeader.
func (t *testResponseWriter) Header() http.Header { return t.ParentHeader }

// Write appends p to Body and reports the whole slice as written.
func (t *testResponseWriter) Write(p []byte) (int, error) {
	t.Body = append(t.Body, p...)
	return len(p), nil
}

// WriteHeader records code as the status. Repeated calls overwrite it.
func (t *testResponseWriter) WriteHeader(code int) { t.Status = code }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment