Skip to content

Instantly share code, notes, and snippets.

@zygm0nt
Forked from ayosec/CORSDirectives.scala
Created April 8, 2014 22:08
Show Gist options
  • Save zygm0nt/10200572 to your computer and use it in GitHub Desktop.
Save zygm0nt/10200572 to your computer and use it in GitHub Desktop.
package foo.bar
import spray.routing._
import spray.http._
import spray.http.StatusCodes.Forbidden
// See https://developer.mozilla.org/en-US/docs/HTTP/Access_control_CORS
/** Models the `Origin` HTTP header.
  *
  * @param origin the raw header value, rendered unchanged as [[value]]
  */
case class Origin(origin: String) extends HttpHeader {
  def name: String = "Origin"
  def lowercaseName: String = "origin"
  def value: String = origin
}
/** Models the `Access-Control-Allow-Origin` HTTP header.
  *
  * @param origin the origin to advertise; rendered unchanged as the header value
  */
case class `Access-Control-Allow-Origin`(origin: String) extends HttpHeader {
  def name: String = "Access-Control-Allow-Origin"
  def lowercaseName: String = "access-control-allow-origin"
  def value: String = origin
}
/** Models the `Access-Control-Allow-Credentials` HTTP header.
  *
  * @param allowed flag rendered as the literal header value `"true"` or `"false"`
  */
case class `Access-Control-Allow-Credentials`(allowed: Boolean) extends HttpHeader {
  def name: String = "Access-Control-Allow-Credentials"
  def lowercaseName: String = "access-control-allow-credentials"
  // Boolean.toString yields exactly "true" / "false".
  def value: String = allowed.toString
}
/** Routing helpers that attach CORS response headers and reject requests
  * coming from an unexpected origin.
  *
  * Must be mixed into an `HttpService`, whose directives it builds upon.
  */
trait CORSDirectives { this: HttpService =>

  /** Adds the CORS response headers for `origin` to every response of the inner route.
    *
    * `Access-Control-Allow-Credentials: true` is only sent together with a
    * concrete origin. Browsers refuse credentialed requests when the allowed
    * origin is the wildcard `*`, so in that case only
    * `Access-Control-Allow-Origin: *` is emitted.
    *
    * @param origin the value for `Access-Control-Allow-Origin` (a concrete origin or `"*"`)
    * @return a directive wrapping the inner route
    */
  def respondWithCORSHeaders(origin: String): Directive0 = {
    val allowOrigin = `Access-Control-Allow-Origin`(origin)
    if (origin == "*")
      respondWithHeaders(allowOrigin)
    else
      respondWithHeaders(allowOrigin, `Access-Control-Allow-Credentials`(true))
  }

  /** Guards `route` with an origin check.
    *
    *  - `origin == "*"`: every request is let through with the CORS headers added.
    *  - No `Origin` request header: the request is treated as a non-CORS request
    *    and `route` runs unchanged, without CORS headers.
    *  - `Origin` header equal to `origin`: `route` runs with the CORS headers added.
    *  - Any other `Origin` header: the request is completed with `403 Forbidden`.
    *
    * @param origin the single allowed origin, or `"*"` to allow any
    * @param route  the inner route to protect
    */
  def corsFilter(origin: String)(route: Route): Route =
    if (origin == "*")
      respondWithCORSHeaders("*")(route)
    else
      optionalHeaderValueByName("Origin") {
        case None =>
          route
        case Some(clientOrigin) if clientOrigin == origin =>
          respondWithCORSHeaders(origin)(route)
        case Some(_) =>
          complete(Forbidden, Nil, "Invalid origin") // A Rejection might fit better here
      }
}
package foo.bar
import org.scalatest._
import org.scalatest.matchers.MustMatchers
import akka.actor.ActorSystem
import spray.testkit.ScalatestRouteTest
import spray.routing._
import spray.http._
import spray.http.HttpHeaders._
import spray.http.StatusCodes._
/** Route-level tests for `CORSDirectives.corsFilter`. */
class CORSSpec extends WordSpec
  with ScalatestRouteTest
  with MustMatchers
  with HttpService
  with CORSDirectives {

  // Connect the routing DSL to the test kit's ActorSystem instead of starting
  // a second system that nothing would ever shut down.
  def actorRefFactory = system

  /** Sends a GET request through `corsFilter(filter)` and runs `checks` on the response.
    *
    * @param origin value of the `Origin` request header, or `None` to send no such header
    * @param filter the allowed origin handed to `corsFilter`
    * @param checks assertions evaluated inside the test kit's `check` block
    */
  def testRequest(origin: Option[String], filter: String)(checks: => Unit): Unit = {
    // Build the request first. `~>` binds tighter than an infix `getOrElse`, so
    // chaining everything in one infix expression would attach the route and the
    // checks to the `getOrElse` default only, silently skipping them whenever an
    // origin is supplied.
    val request: HttpRequest = origin
      .map(o => Get().withHeaders(List(Origin(o))))
      .getOrElse(Get())

    request ~> corsFilter(filter)(complete("OK")) ~> check { checks }
  }

  "The CORS filter" must {

    "accept any request with *" in {
      testRequest(Some(""), "*") { status must be (OK) }
      testRequest(Some("http://a.tld"), "*") { status must be (OK) }
      testRequest(None, "*") { status must be (OK) }
    }

    "accept only valid requests with the same origin" in {
      val filter = "http://a.tld"
      testRequest(Some(filter), filter) { status must be (OK) }
      testRequest(None, filter) { status must be (OK) } // Non-CORS requests are valid
      testRequest(Some(""), filter) { status must be (Forbidden) }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment