-
-
Save fedesilva/8f10b9aa19d6ea830505 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com> | |
*/ | |
import akka.actor.ActorSystem | |
import akka.http.Http | |
import akka.http.Http.IncomingConnection | |
import akka.http.model.HttpEntity | |
import akka.http.model.Multipart.FormData | |
import akka.http.server.{ Directives, Route, RoutingSetup } | |
import akka.http.unmarshalling.Unmarshal | |
import akka.stream.{ ActorFlowMaterializer, FlowMaterializer } | |
import akka.stream.scaladsl.{ Keep, Flow, Sink, Source } | |
import akka.testkit.TestKit | |
import akka.util.ByteString | |
import com.typesafe.config.{ Config, ConfigFactory } | |
import java.io.{ OutputStream, BufferedOutputStream, FileOutputStream, File } | |
import java.net.InetSocketAddress | |
import java.nio.channels.ServerSocketChannel | |
import org.scalatest.concurrent.Timeouts | |
import org.scalatest.time.{ Milliseconds, Span } | |
import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpecLike } | |
import scala.annotation.tailrec | |
import scala.concurrent.duration._ | |
import scala.concurrent.{ Await, ExecutionContext, Future } | |
import scala.sys.process._ | |
import scala.util.Random | |
object FormUpload {

  /** When true, per-chunk read progress and the curl command line are printed to stdout. */
  val debug = true

  /**
   * Upper bound for reading a single blob part on the server and for the whole
   * client-side upload in the spec. Declared before its first use for readability.
   */
  val maxDuration: FiniteDuration = 10.seconds

  /**
   * ActorSystem configuration for the spec: installs the testkit's `TestEventListener`
   * as logger and logs at DEBUG level. (`akka.loggers` replaces the old
   * `akka.event-handlers` key.)
   */
  val testConfig: Config = ConfigFactory.parseString("""
    akka.loggers = ["akka.testkit.TestEventListener"]
    akka.loglevel = DEBUG""")

  /**
   * Asks the OS for an ephemeral port on `interface` by binding port 0, then releases it.
   *
   * The socket is closed before returning, so another process could grab the port before
   * the caller binds it; this is accepted for tests.
   *
   * @param interface local address to bind, defaults to the IPv4 loopback
   * @return the interface together with the port the OS picked
   */
  def temporaryServerAddress(interface: String = "127.0.0.1"): InetSocketAddress = {
    val serverSocket = ServerSocketChannel.open()
    try {
      serverSocket.socket.bind(new InetSocketAddress(interface, 0))
      new InetSocketAddress(interface, serverSocket.socket.getLocalPort)
    } finally serverSocket.close()
  }

  /** Same as [[temporaryServerAddress]], returned as a `(hostAddress, port)` pair. */
  def temporaryServerAddressAndPort(interface: String = "127.0.0.1"): (String, Int) = {
    val socketAddress = temporaryServerAddress(interface)
    socketAddress.getAddress.getHostAddress -> socketAddress.getPort
  }

  /**
   * Binds an HTTP server serving `route` on a temporary local port, runs `block` with the
   * bound address and port, and unbinds the server afterwards (also when `block` throws).
   *
   * @param route the route every accepted connection is handled with
   * @param block test code that talks to the running server
   */
  def withHttpServer(route: Route)(block: (String, Int) ⇒ Unit)(
    implicit system: ActorSystem, materializer: ActorFlowMaterializer, setup: RoutingSetup): Unit = {
    val (address, port) = temporaryServerAddressAndPort()
    val connectionSource = Http().bind(address, port)
    val handleSink = Flow[IncomingConnection].toMat(Sink.foreach(_.handleWith(route)))(Keep.right)
    // Only the binding future is needed; the connection-stream completion is not used.
    val bindingFuture = connectionSource.toMat(handleSink)(Keep.left).run()
    val binding = Await.result(bindingFuture, 5.seconds)
    try block(address, port)
    finally {
      // Wait for the unbind so the port is actually released before we return.
      // `Await.ready` does not rethrow a failed unbind, so it cannot mask an error from `block`.
      Await.ready(binding.unbind(), 5.seconds)
    }
  }

  /** Name of the multipart form field whose file content is counted by [[readBlob]]. */
  val blobPartName = "theBlob"

  /**
   * A form part to be counted. An empty `filename` marks a part that was ignored and
   * already drained.
   */
  case class BlobPart(filename: String, source: Source[ByteString, Unit])

  /**
   * Consumes all parts of a multipart form and reports how many bytes each part contributed.
   *
   * Parts named [[blobPartName]] that carry a filename are counted. Every other part is
   * drained and reported as 0.
   *
   * @param bodyParts the form's parts, in arrival order
   * @return the per-part byte counts, each followed by a single space, e.g. `"0 2411724 "`
   */
  def readBlob(bodyParts: Source[FormData.BodyPart, Unit])(implicit mat: FlowMaterializer, ec: ExecutionContext): Future[String] = {
    // Counts the bytes of one part. It blocks the stream stage (bounded by maxDuration) so
    // each part's data is fully consumed before the next part is pulled.
    def countBytes(blob: BlobPart): Long =
      if (blob.filename.isEmpty) 0L
      else Await.result(blob.source.runFold(0L) { (sum, bytes) ⇒
        val newSum = sum + bytes.size
        if (debug) println(s"Total $newSum after reading ${bytes.size}")
        newSum
      }, maxDuration)

    val blobParts = bodyParts.collect {
      case part @ FormData.BodyPart(name, entity, _, _) ⇒
        part.filename match {
          case Some(filename) if name == blobPartName ⇒
            BlobPart(filename, entity.dataBytes)
          case _ ⇒
            // Ignore form parts we don't recognize by draining them. Note: this still reads
            // all of their data, which is a potential DoS vector.
            entity.dataBytes.runWith(Sink.ignore())
            BlobPart("", Source.empty())
        }
    }
    blobParts.map(countBytes).map(_.toString + " ").runFold("")(_ + _)
  }

  /**
   * Writes `size` pseudo-random bytes (seeded by `size`, so the content is reproducible)
   * to a new temporary file.
   *
   * @param size number of bytes to write; must be >= 0
   * @return canonical path of the created temp file; the caller is responsible for deleting it
   * @throws IllegalArgumentException if `size` is negative (checked before any file is created)
   */
  def createBlobOfSize(size: Long): String = {
    // Validate first so a bad size cannot leave an orphaned temp file behind.
    require(size >= 0, s"Cannot write less than 0 bytes [$size]")

    val chunkSize = 1024
    val chunk = Array.ofDim[Byte](chunkSize)

    @tailrec
    def writeRandomBytes(random: Random, remaining: Long, output: OutputStream): Unit =
      if (remaining > 0) {
        random.nextBytes(chunk)
        val writeSize = math.min(remaining, chunkSize.toLong).toInt
        output.write(chunk, 0, writeSize)
        writeRandomBytes(random, remaining - writeSize, output)
      }

    val rnd = new Random(size)
    // A fixed prefix guarantees the >= 3 chars File.createTempFile requires;
    // a bare hex long can be shorter.
    val temp = File.createTempFile(s"blob-${rnd.nextLong.toHexString}-", ".tmp")
    val out = new BufferedOutputStream(new FileOutputStream(temp))
    try writeRandomBytes(rnd, size, out)
    finally out.close()
    temp.getCanonicalPath
  }
}
/**
 * End-to-end spec: starts an HTTP server whose `/form` route counts the bytes of an
 * uploaded multipart blob, then uploads a random ~2.3 MiB file to it with the `curl` CLI.
 *
 * Requires `curl` on the PATH.
 */
class FormUpload extends WordSpecLike with Matchers with BeforeAndAfterAll with Directives with Timeouts {
  import FormUpload._

  implicit val system: ActorSystem = ActorSystem(getClass.getSimpleName, testConfig)

  /** Path of the random blob uploaded by the spec; created in `beforeAll`, deleted in `afterAll`. */
  var blobPath: String = ""

  override def beforeAll(): Unit = {
    // (1024 * 1024 * 2.3).toLong == 2411724 bytes, i.e. about 2.3 MiB.
    blobPath = createBlobOfSize((1024 * 1024 * 2.3).toLong)
  }

  final override def afterAll(): Unit = {
    // Always shut the actor system down, even if cleaning up the blob fails.
    try {
      if (blobPath.nonEmpty) new File(blobPath).delete()
    } finally TestKit.shutdownActorSystem(system)
  }

  import system.dispatcher
  implicit val materializer: ActorFlowMaterializer = ActorFlowMaterializer()

  "An HTTP Server" should {
    "Accept a simple large multipart form upload" in {
      withHttpServer(
        // format: OFF
        path("form") {
          post {
            extractRequest { request =>
              complete(loadSimpleForm(request.entity))
            }
          }
        // format: ON
        }) { (address, port) ⇒
        failAfter(Span(maxDuration.toMillis, Milliseconds)) {
          uploadSimpleForm(address, port) should be(0)
        }
      }
    }
  }

  /**
   * Posts the blob twice with curl, once under an ignored field name and once as
   * [[FormUpload.blobPartName]].
   *
   * @return curl's exit code. Because `--fail` is passed, an HTTP error status (>= 400)
   *         from the server yields a non-zero code (22) instead of 0.
   */
  private def uploadSimpleForm(address: String, port: Int): Int = {
    val command = List(
      "curl",
      "-v", // verbose
      "--fail", // make HTTP error responses fail the upload (exit code 22)
      "-H", "Expect:", // remove Expect: 100-continue
      "--form", s"""ignoredFormField=@"$blobPath"""",
      "--form", s"""$blobPartName=@"$blobPath"""",
      s"http://$address:$port/form")
    if (debug) println(s"About to execute: ${command.mkString(" ")}")
    command.!
  }

  /** Unmarshals the request entity as multipart form data and counts the blob bytes. */
  private def loadSimpleForm(entity: HttpEntity): Future[String] =
    for {
      multiPartFormData ← Unmarshal(entity).to[FormData]
      blobResult ← readBlob(multiPartFormData.parts)
    } yield "Read the blob " + blobResult
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment