Created
October 5, 2016 18:56
-
-
Save deadcheat/635fce6e413af653a9d41121b0fdfee8 to your computer and use it in GitHub Desktop.
CORS middleware for the goadesign/goa framework, heavily inspired by labstack/echo's middleware/cors.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cors | |
import ( | |
"net/http" | |
"strconv" | |
"strings" | |
"github.com/goadesign/goa" | |
"golang.org/x/net/context" | |
) | |
type ( | |
// GoaCORSConfig config struct | |
GoaCORSConfig struct { | |
Skipper Skipper | |
AllowOrigins []string | |
AllowMethods []string | |
AllowHeaders []string | |
AllowCredentials bool | |
ExposeHeaders []string | |
MaxAge int | |
} | |
) | |
var ( | |
// DefaultGoaCORSConfig is the default CORS middleware config. | |
DefaultGoaCORSConfig = GoaCORSConfig{ | |
Skipper: defaultSkipper, | |
AllowOrigins: []string{"*"}, | |
AllowMethods: []string{GET, HEAD, PUT, PATCH, POST, DELETE}, | |
} | |
) | |
// HTTP method names used by the CORS middleware and its configuration.
// They alias the net/http method constants, so the values cannot drift from
// the standard library's spelling.
const (
	// DELETE is the HTTP DELETE method.
	DELETE = http.MethodDelete
	// GET is the HTTP GET method.
	GET = http.MethodGet
	// HEAD is the HTTP HEAD method.
	HEAD = http.MethodHead
	// OPTIONS is the HTTP OPTIONS method; the middleware treats OPTIONS
	// requests as CORS preflight requests.
	OPTIONS = http.MethodOptions
	// PATCH is the HTTP PATCH method.
	PATCH = http.MethodPatch
	// POST is the HTTP POST method.
	POST = http.MethodPost
	// PUT is the HTTP PUT method.
	PUT = http.MethodPut
)
// Request headers the CORS middleware inspects.
const (
	// HeaderOrigin is the "Origin" request header.
	HeaderOrigin = "Origin"
	// HeaderAccessControlRequestMethod is the "Access-Control-Request-Method"
	// preflight request header.
	HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
	// HeaderAccessControlRequestHeaders is the "Access-Control-Request-Headers"
	// preflight request header.
	HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
)

// Response headers the CORS middleware writes.
const (
	// HeaderVary is the "Vary" response header.
	HeaderVary = "Vary"
	// HeaderAccessControlAllowOrigin is the "Access-Control-Allow-Origin"
	// response header.
	HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
	// HeaderAccessControlAllowMethods is the "Access-Control-Allow-Methods"
	// response header.
	HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
	// HeaderAccessControlAllowHeaders is the "Access-Control-Allow-Headers"
	// response header.
	HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
	// HeaderAccessControlAllowCredentials is the
	// "Access-Control-Allow-Credentials" response header.
	HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
	// HeaderAccessControlExposeHeaders is the "Access-Control-Expose-Headers"
	// response header.
	HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
	// HeaderAccessControlMaxAge is the "Access-Control-Max-Age" response header.
	HeaderAccessControlMaxAge = "Access-Control-Max-Age"
)

// HeaderContentType is the "Content-Type" header.
const HeaderContentType = "Content-Type"
// GoaCORS check from default config | |
func GoaCORS(service *goa.Service) goa.Middleware { | |
return GoaCORSWithConfig(service, DefaultGoaCORSConfig) | |
} | |
// GoaCORSWithConfig check cors-header | |
func GoaCORSWithConfig(service *goa.Service, conf GoaCORSConfig) goa.Middleware { | |
if conf.Skipper == nil { | |
conf.Skipper = DefaultGoaCORSConfig.Skipper | |
} | |
if len(conf.AllowOrigins) == 0 { | |
conf.AllowOrigins = DefaultGoaCORSConfig.AllowOrigins | |
} | |
if len(conf.AllowMethods) == 0 { | |
conf.AllowMethods = DefaultGoaCORSConfig.AllowMethods | |
} | |
allowMethods := strings.Join(conf.AllowMethods, ",") | |
allowHeaders := strings.Join(conf.AllowHeaders, ",") | |
exposeHeaders := strings.Join(conf.ExposeHeaders, ",") | |
maxAge := strconv.Itoa(conf.MaxAge) | |
return func(h goa.Handler) goa.Handler { | |
return func(c context.Context, rw http.ResponseWriter, req *http.Request) error { | |
// Skipper | |
if conf.Skipper(c, rw, req) { | |
return h(c, rw, req) | |
} | |
origin := req.Header.Get(HeaderOrigin) | |
_, hasOrigin := req.Header[HeaderOrigin] | |
// isOriginEmpty := (origin == "") | |
// Check allowed origins | |
allowedOrigin := "" | |
for _, o := range conf.AllowOrigins { | |
if o == "*" || o == origin { | |
allowedOrigin = o | |
break | |
} | |
} | |
// Simple request | |
if req.Method != OPTIONS { | |
rw.Header().Add(HeaderVary, HeaderOrigin) | |
if !hasOrigin || allowedOrigin == "" { | |
return h(c, rw, req) | |
} | |
rw.Header().Set(HeaderAccessControlAllowOrigin, allowedOrigin) | |
if conf.AllowCredentials { | |
rw.Header().Set(HeaderAccessControlAllowCredentials, "true") | |
} | |
if exposeHeaders != "" { | |
rw.Header().Set(HeaderAccessControlExposeHeaders, exposeHeaders) | |
} | |
return h(c, rw, req) | |
} | |
// Preflight request | |
rw.Header().Add(HeaderVary, HeaderOrigin) | |
rw.Header().Add(HeaderVary, HeaderAccessControlRequestMethod) | |
rw.Header().Add(HeaderVary, HeaderAccessControlRequestHeaders) | |
if !hasOrigin || allowedOrigin == "" { | |
return h(c, rw, req) | |
} | |
rw.Header().Set(HeaderAccessControlAllowOrigin, allowedOrigin) | |
rw.Header().Set(HeaderAccessControlAllowMethods, allowMethods) | |
if conf.AllowCredentials { | |
rw.Header().Set(HeaderAccessControlAllowCredentials, "true") | |
} | |
if allowHeaders != "" { | |
rw.Header().Set(HeaderAccessControlAllowHeaders, allowHeaders) | |
} else { | |
h := req.Header.Get(HeaderAccessControlRequestHeaders) | |
if h != "" { | |
rw.Header().Set(HeaderAccessControlAllowHeaders, h) | |
} | |
} | |
if conf.MaxAge > 0 { | |
rw.Header().Set(HeaderAccessControlMaxAge, maxAge) | |
} | |
return service.Send(c, http.StatusNoContent, http.StatusText(http.StatusNoContent)) | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cors_test | |
import ( | |
"net/http" | |
"testing" | |
"app/middleware/cors" | |
"golang.org/x/net/context" | |
"github.com/goadesign/goa" | |
. "github.com/smartystreets/goconvey/convey" | |
) | |
func TestCORS(t *testing.T) { | |
var ctx context.Context | |
var req *http.Request | |
var rw http.ResponseWriter | |
var service *goa.Service | |
Convey("no origin header", t, func() { | |
service = newService(nil) | |
req, _ = http.NewRequest(cors.GET, "/", nil) | |
rw = newTestResponseWriter() | |
ctx = newContext(service, rw, req, nil) | |
var newCtx context.Context | |
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { | |
newCtx = ctx | |
return service.Send(ctx, http.StatusOK, "ok") | |
} | |
t := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{ | |
AllowCredentials: true, | |
})(h) | |
err := t(ctx, rw, req) | |
So(err, ShouldBeNil) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "") | |
}) | |
Convey("Empty origin header", t, func() { | |
service = newService(nil) | |
req, _ = http.NewRequest(cors.GET, "/", nil) | |
req.Header.Set(cors.HeaderOrigin, "") | |
rw = newTestResponseWriter() | |
ctx = newContext(service, rw, req, nil) | |
var newCtx context.Context | |
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { | |
newCtx = ctx | |
return service.Send(ctx, http.StatusOK, "ok") | |
} | |
t := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{ | |
AllowCredentials: true, | |
})(h) | |
err := t(ctx, rw, req) | |
So(err, ShouldBeNil) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "*") | |
}) | |
Convey("Wildcard Origin", t, func() { | |
service = newService(nil) | |
req, _ = http.NewRequest(cors.GET, "/", nil) | |
req.Header.Set(cors.HeaderOrigin, "localhost") | |
rw = newTestResponseWriter() | |
ctx = newContext(service, rw, req, nil) | |
var newCtx context.Context | |
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { | |
newCtx = ctx | |
return service.Send(ctx, http.StatusOK, "ok") | |
} | |
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{ | |
AllowCredentials: true, | |
AllowOrigins: []string{"*"}, | |
})(h) | |
err := testee(ctx, rw, req) | |
So(err, ShouldBeNil) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "*") | |
}) | |
Convey("Simple Request", t, func() { | |
service = newService(nil) | |
req, _ = http.NewRequest(cors.GET, "/", nil) | |
req.Header.Set(cors.HeaderOrigin, "localhost") | |
rw = newTestResponseWriter() | |
ctx = newContext(service, rw, req, nil) | |
var newCtx context.Context | |
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { | |
newCtx = ctx | |
return service.Send(ctx, http.StatusOK, "ok") | |
} | |
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{ | |
AllowCredentials: true, | |
AllowOrigins: []string{"localhost"}, | |
})(h) | |
err := testee(ctx, rw, req) | |
So(err, ShouldBeNil) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "localhost") | |
}) | |
Convey("Preflight request", t, func() { | |
service = newService(nil) | |
req, _ = http.NewRequest(cors.OPTIONS, "/", nil) | |
req.Header.Set(cors.HeaderOrigin, "localhost") | |
req.Header.Set(cors.HeaderContentType, "application/json") | |
rw = newTestResponseWriter() | |
ctx = newContext(service, rw, req, nil) | |
var newCtx context.Context | |
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { | |
newCtx = ctx | |
return service.Send(ctx, http.StatusOK, "ok") | |
} | |
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{ | |
AllowCredentials: true, | |
AllowOrigins: []string{"localhost"}, | |
MaxAge: 3600, | |
})(h) | |
err := testee(ctx, rw, req) | |
So(err, ShouldBeNil) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "localhost") | |
So(rw.Header().Get(cors.HeaderAccessControlAllowMethods), ShouldNotBeEmpty) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowCredentials), ShouldEqual, "true") | |
So(rw.Header().Get(cors.HeaderAccessControlMaxAge), ShouldEqual, "3600") | |
}) | |
Convey("not allowed header", t, func() { | |
service = newService(nil) | |
req, _ = http.NewRequest(cors.GET, "/", nil) | |
req.Header.Set(cors.HeaderOrigin, "localhost") | |
rw = newTestResponseWriter() | |
ctx = newContext(service, rw, req, nil) | |
var newCtx context.Context | |
h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { | |
newCtx = ctx | |
return service.Send(ctx, http.StatusOK, "ok") | |
} | |
testee := cors.GoaCORSWithConfig(service, cors.GoaCORSConfig{ | |
AllowCredentials: true, | |
AllowOrigins: []string{"example.com"}, | |
})(h) | |
err := testee(ctx, rw, req) | |
So(err, ShouldBeNil) | |
So(rw.Header().Get(cors.HeaderAccessControlAllowOrigin), ShouldEqual, "") | |
}) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cors | |
import ( | |
"net/http" | |
"golang.org/x/net/context" | |
) | |
// Skipper reports whether the CORS middleware should hand a request straight
// to the wrapped handler, leaving the response headers untouched.
type Skipper func(c context.Context, rw http.ResponseWriter, req *http.Request) bool
// defaultSkipper never skips, so the CORS middleware processes every request.
// Its arguments are ignored.
func defaultSkipper(_ context.Context, _ http.ResponseWriter, _ *http.Request) bool {
	const skip = false
	return skip
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cors_test | |
import ( | |
"net/http" | |
"net/url" | |
"github.com/goadesign/goa" | |
"golang.org/x/net/context" | |
) | |
// Helper that sets up a "working" service | |
func newService(logger goa.LogAdapter) *goa.Service { | |
service := goa.New("test") | |
service.Encoder.Register(goa.NewJSONEncoder, "*/*") | |
service.Decoder.Register(goa.NewJSONDecoder, "*/*") | |
service.WithLogger(logger) | |
return service | |
} | |
// Creates a test context | |
func newContext(service *goa.Service, rw http.ResponseWriter, req *http.Request, params url.Values) context.Context { | |
ctrl := service.NewController("test") | |
return goa.NewContext(ctrl.Context, rw, req, params) | |
} | |
// logEntry is one message captured by testLogger.
type logEntry struct {
	// Msg is the message passed to Info or Error.
	Msg string
	// Data is the logger context followed by the call's extra arguments.
	Data []interface{}
}

// testLogger is a logging test double: Info and Error calls are recorded
// instead of printed. (Its New method returns it as a goa.LogAdapter.)
type testLogger struct {
	// Context holds the values accumulated through New; they prefix the Data
	// of every recorded entry.
	Context []interface{}
	// InfoEntries records Info calls in order.
	InfoEntries []logEntry
	// ErrorEntries records Error calls in order.
	ErrorEntries []logEntry
}

// Info records msg together with the logger context followed by data.
func (t *testLogger) Info(msg string, data ...interface{}) {
	t.InfoEntries = append(t.InfoEntries, t.entry(msg, data))
}

// Error records msg together with the logger context followed by data.
func (t *testLogger) Error(msg string, data ...interface{}) {
	t.ErrorEntries = append(t.ErrorEntries, t.entry(msg, data))
}

// entry builds a logEntry whose Data is a freshly allocated copy of the logger
// context followed by data. Appending straight onto t.Context would reuse its
// spare capacity, letting consecutive entries share, and overwrite, one
// backing array.
func (t *testLogger) entry(msg string, data []interface{}) logEntry {
	all := make([]interface{}, 0, len(t.Context)+len(data))
	all = append(all, t.Context...)
	all = append(all, data...)
	return logEntry{Msg: msg, Data: all}
}
func (t *testLogger) New(data ...interface{}) goa.LogAdapter { | |
t.Context = append(t.Context, data...) | |
return t | |
} | |
// testResponseWriter is an in-memory http.ResponseWriter that records the
// header map, the body bytes and the status code written to it.
type testResponseWriter struct {
	// ParentHeader is the header map handed out by Header.
	ParentHeader http.Header
	// Body accumulates every byte passed to Write.
	Body []byte
	// Status is the code from the most recent WriteHeader call (0 if none).
	Status int
}

// Compile-time check that *testResponseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*testResponseWriter)(nil)

// newTestResponseWriter returns a testResponseWriter with an empty, non-nil
// header map.
func newTestResponseWriter() *testResponseWriter {
	return &testResponseWriter{ParentHeader: http.Header{}}
}

// Header returns the recorded header map; callers mutate it in place.
func (t *testResponseWriter) Header() http.Header { return t.ParentHeader }

// Write appends p to Body and always reports that all of p was written.
func (t *testResponseWriter) Write(p []byte) (int, error) {
	t.Body = append(t.Body, p...)
	return len(p), nil
}

// WriteHeader stores code as the status; each call overwrites the previous one.
func (t *testResponseWriter) WriteHeader(code int) { t.Status = code }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment