Skip to content

Instantly share code, notes, and snippets.

@casualjim
Created October 22, 2011 21:20
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save casualjim/1306507 to your computer and use it in GitHub Desktop.
Save casualjim/1306507 to your computer and use it in GitHub Desktop.
CORS Support for scalatra
package backchat
package web
import javax.servlet.http.{ HttpServletResponse, HttpServletRequest }
import org.scalatra._
import collection.JavaConversions._
/**
 * Header names and lookup tables used by the CORS support trait.
 *
 * The public values are the literal HTTP header names; the private ones are
 * upper-cased tables that are only consulted from the companion trait.
 */
object CORSSupport {

  // Headers sent by the user agent on a cross-origin request.
  val ORIGIN_HEADER: String = "Origin"
  val ACCESS_CONTROL_REQUEST_METHOD_HEADER: String = "Access-Control-Request-Method"
  val ACCESS_CONTROL_REQUEST_HEADERS_HEADER: String = "Access-Control-Request-Headers"

  // Headers written back to the user agent.
  val ACCESS_CONTROL_ALLOW_ORIGIN_HEADER: String = "Access-Control-Allow-Origin"
  val ACCESS_CONTROL_ALLOW_METHODS_HEADER: String = "Access-Control-Allow-Methods"
  val ACCESS_CONTROL_ALLOW_HEADERS_HEADER: String = "Access-Control-Allow-Headers"
  val ACCESS_CONTROL_MAX_AGE_HEADER: String = "Access-Control-Max-Age"
  val ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER: String = "Access-Control-Allow-Credentials"
  // private val ACCESS_CONTROL_EXPOSE_HEADERS_HEADER = "Access-Control-Expose-Headers"

  /** Wildcard entry in the allowed-origins configuration. */
  private val ANY_ORIGIN: String = "*"

  /** Upper-cased header names that are treated as "simple" without any further check. */
  private val SIMPLE_HEADERS: List[String] =
    ORIGIN_HEADER.toUpperCase(ENGLISH) :: "ACCEPT" :: "ACCEPT-LANGUAGE" :: "CONTENT-LANGUAGE" :: Nil

  /** Upper-cased content-type prefixes under which `Content-Type` counts as a simple header. */
  private val SIMPLE_CONTENT_TYPES: List[String] =
    "APPLICATION/X-WWW-FORM-URLENCODED" :: "MULTIPART/FORM-DATA" :: "TEXT/PLAIN" :: Nil

  /** Every CORS-related header name declared above, request and response side alike. */
  val CORS_HEADERS: List[String] =
    ORIGIN_HEADER ::
      ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER ::
      ACCESS_CONTROL_ALLOW_HEADERS_HEADER ::
      ACCESS_CONTROL_ALLOW_METHODS_HEADER ::
      ACCESS_CONTROL_ALLOW_ORIGIN_HEADER ::
      ACCESS_CONTROL_MAX_AGE_HEADER ::
      ACCESS_CONTROL_REQUEST_HEADERS_HEADER ::
      ACCESS_CONTROL_REQUEST_METHOD_HEADER ::
      Nil
  // private val SIMPLE_RESPONSE_HEADERS = List("CACHE-CONTROL", "CONTENT-LANGUAGE", "EXPIRES", "LAST-MODIFIED", "PRAGMA", "CONTENT-TYPE")
}
/**
 * Adds Cross-Origin Resource Sharing (CORS) handling to a Scalatra kernel.
 *
 * Every request passing through `handle` is classified:
 *  - an `OPTIONS` request that passes all preflight checks is answered directly
 *    with the CORS response headers and is NOT forwarded to the routes;
 *  - any other request carrying an `Origin` header from an allowed origin gets
 *    the `Access-Control-Allow-*` headers added and is then handled normally;
 *  - everything else is handled normally without any CORS headers.
 *
 * The `// n.n.n` markers are presumably step numbers from the W3C CORS
 * specification (the numbering appears to mix draft revisions -- confirm).
 *
 * NOTE(review): `toOption` and `isNotBlank` on possibly-null strings, as well as
 * `ENGLISH`, are not defined in this file; presumably they come from project-wide
 * implicits / a package object -- confirm.
 */
trait CORSSupport extends Handler { self: ScalatraKernel with Logging ⇒

  import CORSSupport._

  /** The CORS settings in effect. Override to supply a different configuration. */
  protected def corsConfig = Config.CORS

  // The settings are read once, while the trait is initialised.
  private val anyOriginAllowed: Boolean = corsConfig.allowedOrigins.contains(ANY_ORIGIN)
  private val allowedOrigins = corsConfig.allowedOrigins
  private val allowedMethods = corsConfig.allowedMethods
  private val allowedHeaders = corsConfig.allowedHeaders
  private val preflightMaxAge: Int = corsConfig.preflightMaxAge
  private val allowCredentials: Boolean = corsConfig.allowCredentials

  logger debug "Enabled CORS Support with:\nallowedOrigins: %s\nallowedMethods: %s\nallowedHeaders: %s".format(
    allowedOrigins mkString ", ",
    allowedMethods mkString ", ",
    allowedHeaders mkString ", ")

  /**
   * Writes the full preflight response (origin, credentials, max-age, methods
   * and headers) and flushes it; the request is not dispatched to any route.
   */
  protected def handlePreflightRequest(): Unit = {
    logger trace "handling preflight request"
    // 5.2.7
    augmentSimpleRequest()
    // 5.2.8
    if (preflightMaxAge > 0) response.setHeader(ACCESS_CONTROL_MAX_AGE_HEADER, preflightMaxAge.toString)
    // 5.2.9
    response.setHeader(ACCESS_CONTROL_ALLOW_METHODS_HEADER, allowedMethods mkString ",")
    // 5.2.10
    response.setHeader(ACCESS_CONTROL_ALLOW_HEADERS_HEADER, allowedHeaders mkString ",")
    response.flushBuffer()
    response.getOutputStream.flush()
  }

  /**
   * Sets `Access-Control-Allow-Origin` (and, when credentials are allowed,
   * `Access-Control-Allow-Credentials`) on the current response.
   *
   * The wildcard is only sent when credentials are disabled; otherwise the
   * request's own `Origin` value is echoed. Callers are expected to have
   * verified the origin beforehand.
   */
  protected def augmentSimpleRequest(): Unit = {
    val hdr = if (anyOriginAllowed && !allowCredentials) ANY_ORIGIN else request.getHeader(ORIGIN_HEADER)
    response.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, hdr)
    if (allowCredentials) response.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER, "true")
    /*
    if (allowedHeaders.nonEmpty) {
    val hdrs = allowedHeaders.filterNot(hn => SIMPLE_RESPONSE_HEADERS.contains(hn.toUpperCase(ENGLISH))).mkString(",")
    response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS_HEADER, hdrs)
    }
    */
  }

  /** True when the wildcard is configured or the request's `Origin` is in the allowed list. */
  private def originMatches: Boolean = // 6.2.2
    anyOriginAllowed || (allowedOrigins contains request.getHeader(ORIGIN_HEADER))

  /** False for WebSocket upgrade requests and for paths containing `eb_ping`. */
  private def isEnabled: Boolean =
    !("Upgrade".equalsIgnoreCase(request.getHeader("Connection")) &&
      "WebSocket".equalsIgnoreCase(request.getHeader("Upgrade"))) &&
      !requestPath.contains("eb_ping") // don't do anything for the ping endpoint

  /** True when `routes.matchingMethods` reports at least one method. */
  private def isValidRoute: Boolean = routes.matchingMethods.nonEmpty

  /**
   * True when the current request satisfies every preflight condition.
   * All checks are evaluated up front (no short-circuit) so that each
   * individual outcome can be traced.
   */
  private def isPreflightRequest: Boolean = {
    val isCors = isCORSRequest
    val validRoute = isValidRoute
    val isPreflight = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER).isNotBlank
    val enabled = isEnabled
    val matchesOrigin = originMatches
    val methodAllowd = allowsMethod
    val allowsHeaders = headersAreAllowed
    val result = isCors && validRoute && isPreflight && enabled && matchesOrigin && methodAllowd && allowsHeaders
    logger trace "This is a preflight validation check. valid? %s".format(result)
    logger trace "cors? %s, route? %s, preflight? %s, enabled? %s, origin? %s, method? %s, header? %s".format(
      isCors, validRoute, isPreflight, enabled, matchesOrigin, methodAllowd, allowsHeaders)
    result
  }

  /** True when the request carries a non-blank `Origin` header. */
  private def isCORSRequest: Boolean = { // 6.x.1
    val h = request.getHeader(ORIGIN_HEADER)
    val result = h.isNotBlank
    if (!result) logger trace ("No origin found in the request")
    else logger trace ("We found the origin: %s".format(h))
    result
  }

  /**
   * True when `header` names a simple header: one of `SIMPLE_HEADERS`, or
   * `Content-Type` while the request's content type starts with one of
   * `SIMPLE_CONTENT_TYPES`.
   */
  private def isSimpleHeader(header: String): Boolean =
    header.toOption exists { h ⇒
      val hu = h.toUpperCase(ENGLISH)
      SIMPLE_HEADERS.contains(hu) || (hu == "CONTENT-TYPE" &&
        // getContentType is null when the request declares no content type
        // (typical for a preflight), so it must not be dereferenced blindly.
        Option(request.getContentType).exists { contentType ⇒
          val upper = contentType.toUpperCase(ENGLISH)
          SIMPLE_CONTENT_TYPES.exists(upper.startsWith(_))
        })
    }

  /**
   * True when the `Origin` header is present and each of its space-separated
   * values is allowed (or the wildcard is configured).
   */
  private def allOriginsMatch: Boolean = // 6.1.2
    request.getHeader(ORIGIN_HEADER).toOption exists { origins ⇒
      val candidates = origins.split(" ")
      candidates.nonEmpty && (anyOriginAllowed || candidates.forall(allowedOrigins.contains))
    }

  /** True for an enabled CORS request from allowed origins that carries only simple headers. */
  private def isSimpleRequest: Boolean = {
    val isCors = isCORSRequest
    val enabled = isEnabled
    val allOrigins = allOriginsMatch
    val res = isCors && enabled && allOrigins && request.getHeaderNames.forall(isSimpleHeader)
    logger trace "This is a simple request: %s, because: %s, %s, %s".format(res, isCors, enabled, allOrigins)
    res
  }

  /** True when `Access-Control-Request-Method` is present and names an allowed method. */
  private def allowsMethod: Boolean = { // 5.2.3 and 5.2.5
    val accessControlRequestMethod = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER)
    logger.trace("%s is %s" format (ACCESS_CONTROL_REQUEST_METHOD_HEADER, accessControlRequestMethod))
    val result = accessControlRequestMethod.isNotBlank && allowedMethods.contains(accessControlRequestMethod.toUpperCase(ENGLISH))
    logger.trace("Method %s is %s among allowed methods %s".format(accessControlRequestMethod, if (result) "" else " not", allowedMethods))
    result
  }

  /**
   * True when `Access-Control-Request-Headers` is absent, or when every header
   * it lists is either configured as allowed, one of the CORS headers, or a
   * simple header.
   */
  private def headersAreAllowed: Boolean = { // 5.2.4 and 5.2.6
    val accessControlRequestHeaders = request.getHeader(ACCESS_CONTROL_REQUEST_HEADERS_HEADER).toOption
    logger.trace("%s is %s".format(ACCESS_CONTROL_REQUEST_HEADERS_HEADER, accessControlRequestHeaders))
    val ah = (allowedHeaders ++ CORS_HEADERS).map(_.trim.toUpperCase(ENGLISH))
    val result = accessControlRequestHeaders forall { hdr ⇒
      val hdrs = hdr.split(",").map(_.trim.toUpperCase(ENGLISH))
      // Join the array for logging; formatting it directly prints only its JVM identity.
      logger.debug("Headers [%s]".format(hdrs.mkString(", ")))
      // Check header by header, so a list mixing allowed and simple headers is accepted.
      hdrs.nonEmpty && hdrs.forall { h ⇒ ah.contains(h) || isSimpleHeader(h) }
    }
    logger.trace("Headers [%s] are %s among allowed headers %s".format(
      accessControlRequestHeaders getOrElse "No headers", if (result) "" else " not", ah))
    result
  }

  /**
   * Binds the request/response for the duration of the call, applies the CORS
   * handling described on the trait and, except for an answered preflight,
   * delegates to the next `handle` in the stack.
   */
  abstract override def handle(req: HttpServletRequest, res: HttpServletResponse): Unit = {
    _request.withValue(req) {
      logger trace "the headers are: %s".format(req.getHeaderNames.mkString(", "))
      _response.withValue(res) {
        request.method match {
          case Options if isPreflightRequest ⇒
            handlePreflightRequest()
          case Get | Post | Head if isSimpleRequest ⇒
            augmentSimpleRequest()
            super.handle(req, res)
          // Only advertise CORS access to origins that are actually allowed.
          // Without the origin check any Origin value would be echoed back,
          // together with Allow-Credentials when credentials are enabled.
          case _ if isCORSRequest && originMatches ⇒
            augmentSimpleRequest()
            super.handle(req, res)
          case _ ⇒ super.handle(req, res)
        }
      }
    }
  }
}
package backchat
package web
package tests
import org.scalatra.test.specs2.ScalatraSpec
import org.scalatra.ScalatraServlet
/**
 * Acceptance specification for the CORS support trait.
 *
 * A servlet with a single `GET /` route is mounted with CORS restricted to
 * `http://www.example.com`; each example issues one request against it and
 * inspects the `Access-Control-Allow-Origin` response header.
 */
class CORSSupportSpec extends ScalatraSpec {

  // Servlet under test, mounted on every path.
  addServlet(new ScalatraServlet with Logging with CORSSupport {
    override protected lazy val corsConfig =
      CORSConfig(List("http://www.example.com"), List("GET", "HEAD", "POST"), "X-Requested-With,Authorization,Content-Type,Accept,Origin".split(","), true)

    get("/") {
      "OK"
    }
  }, "/*")

  def is =
    "The CORS support should" ^
      "augment a valid simple request" ! context.validSimpleRequest ^
      "not touch a regular request" ! context.dontTouchRegularRequest ^
      "respond to a valid preflight request" ! context.validPreflightRequest ^
      "respond to a valid preflight request with headers" ! context.validPreflightRequestWithHeaders ^ end

  /** Bodies of the examples listed in `is`. */
  object context {

    // The one origin the servlet under test is configured to accept.
    private val exampleOrigin = "http://www.example.com"

    // Must stay a `def`: it reads the response of whichever request is currently executing.
    private def allowOriginHeader = response.getHeader(CORSSupport.ACCESS_CONTROL_ALLOW_ORIGIN_HEADER)

    /** A GET carrying the allowed origin gets that origin echoed back. */
    def validSimpleRequest = {
      val requestHeaders = Map(CORSSupport.ORIGIN_HEADER -> exampleOrigin)
      get("/", headers = requestHeaders) {
        allowOriginHeader must_== exampleOrigin
      }
    }

    /** A GET without an `Origin` header receives no allow-origin header. */
    def dontTouchRegularRequest = {
      get("/") {
        allowOriginHeader must beNull
      }
    }

    /** An OPTIONS preflight asking for GET gets the origin echoed back. */
    def validPreflightRequest = {
      val requestHeaders = Map(
        CORSSupport.ORIGIN_HEADER -> exampleOrigin,
        CORSSupport.ACCESS_CONTROL_REQUEST_METHOD_HEADER -> "GET",
        "Content-Type" -> "application/json")
      options("/", headers = requestHeaders) {
        allowOriginHeader must_== exampleOrigin
      }
    }

    /** As above, additionally requesting headers that are all configured as allowed. */
    def validPreflightRequestWithHeaders = {
      val requestHeaders = Map(
        CORSSupport.ORIGIN_HEADER -> exampleOrigin,
        CORSSupport.ACCESS_CONTROL_REQUEST_METHOD_HEADER -> "GET",
        CORSSupport.ACCESS_CONTROL_REQUEST_HEADERS_HEADER -> "Origin, Authorization, Accept",
        "Content-Type" -> "application/json")
      options("/", headers = requestHeaders) {
        allowOriginHeader must_== exampleOrigin
      }
    }
  }
}
@ikennaokpala
Copy link

@casualjim what is CORS, and what exactly are you trying to achieve? It appears to be configuring certain.. pros

@casualjim
Copy link
Author

CORS is Cross-Origin Resource Sharing, something browsers implement to allow you to make cross-domain requests to servers with which you have a "trusted" relationship.

@ikennaokpala
Copy link

Thanks, Ivan, for that explanation. At first I easily took it for CQRS (Command Query Responsibility Segregation), but a deeper look at the code suggested otherwise.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment