Skip to content

Instantly share code, notes, and snippets.

@bantonsson
Created March 3, 2015 12:49
Show Gist options
  • Save bantonsson/881f831db93ec474f9bd to your computer and use it in GitHub Desktop.
Save bantonsson/881f831db93ec474f9bd to your computer and use it in GitHub Desktop.
Akka HTTP 1.0-M4 File upload using a form
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
import akka.actor.ActorSystem
import akka.http.Http
import akka.http.Http.IncomingConnection
import akka.http.model.HttpEntity
import akka.http.model.Multipart.FormData
import akka.http.server.{ Directives, Route, RoutingSetup }
import akka.http.unmarshalling.Unmarshal
import akka.stream.{ ActorFlowMaterializer, FlowMaterializer }
import akka.stream.scaladsl.{ Keep, Flow, Sink, Source }
import akka.testkit.TestKit
import akka.util.ByteString
import com.typesafe.config.{ Config, ConfigFactory }
import java.io.{ OutputStream, BufferedOutputStream, FileOutputStream, File }
import java.net.InetSocketAddress
import java.nio.channels.ServerSocketChannel
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{ Milliseconds, Span }
import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpecLike }
import scala.annotation.tailrec
import scala.concurrent.duration._
import scala.concurrent.{ Await, ExecutionContext, Future }
import scala.sys.process._
import scala.util.Random
object FormUpload {
// Enables the println diagnostics emitted while reading the blob and before running curl.
val debug = true
/**
 * Asks the operating system for a free port on `interface` by binding a
 * server socket to port 0, and returns that interface/port pair.
 *
 * The probing socket is closed before this method returns, so the port is
 * not reserved for the caller.
 *
 * @param interface the local interface to bind to; defaults to the loopback address
 * @return the interface together with the port that was assigned to the probe
 */
def temporaryServerAddress(interface: String = "127.0.0.1"): InetSocketAddress = {
  val channel = ServerSocketChannel.open()
  try {
    val probe = channel.socket
    probe.bind(new InetSocketAddress(interface, 0))
    new InetSocketAddress(interface, probe.getLocalPort)
  } finally {
    channel.close()
  }
}
/**
 * Same as `temporaryServerAddress`, but returns the result as a
 * (host address, port) pair.
 *
 * @param interface the local interface to probe; defaults to the loopback address
 * @return the textual host address and the free port
 */
def temporaryServerAddressAndPort(interface: String = "127.0.0.1"): (String, Int) = {
  val free = temporaryServerAddress(interface)
  (free.getAddress.getHostAddress, free.getPort)
}
/**
 * Runs `block` against a temporary HTTP server that handles every incoming
 * connection with `route`.
 *
 * The server is bound to the address and port returned by
 * `temporaryServerAddressAndPort()`. It is unbound again once `block` has
 * finished, whether it returned normally or threw, and this method waits for
 * the unbind to complete before returning.
 *
 * @param route the route used to handle incoming connections
 * @param block the code to run while the server is bound; receives the bound address and port
 */
def withHttpServer(route: Route)(block: (String, Int) ⇒ Unit)(
  implicit system: ActorSystem, materializer: ActorFlowMaterializer, setup: RoutingSetup): Unit = {
  val (address, port) = temporaryServerAddressAndPort()
  val connectionSource = Http().bind(address, port)
  val handleSink = Flow[IncomingConnection].toMat(Sink.foreach(_.handleWith(route)))(Keep.right)
  // Only the binding future is needed here; the completion future of the handling sink is not used.
  val (bindingFuture, _) = connectionSource.toMat(handleSink)(Keep.both).run()
  val binding = Await.result(bindingFuture, 5.seconds)
  try {
    block(address, port)
  } finally {
    // Wait for the unbind so the port is released, and unbind failures surface, before returning.
    Await.result(binding.unbind(), 5.seconds)
  }
}
// Name of the multipart form field whose uploaded file is measured by readBlob.
val blobPartName = "theBlob"
// One form part to measure: the filename it was uploaded with and the stream of its bytes.
case class BlobPart(filename: String, source: Source[ByteString, Unit])
/**
 * Consumes all parts of a multipart form and reports how many bytes each one
 * contributed.
 *
 * A part named `blobPartName` that carries a filename is read completely and
 * its byte count is reported. Every other part is drained and reported as 0.
 *
 * @param bodyParts the parts of the multipart form, in arrival order
 * @return a future string holding one byte count per part, each followed by a single space
 */
def readBlob(bodyParts: Source[FormData.BodyPart, Unit])(implicit mat: FlowMaterializer, ec: ExecutionContext): Future[String] = {
  // Counts the bytes of one part by running its source to completion.
  // The source is always consumed: an uploaded file whose filename is the empty
  // string must still be drained, otherwise the enclosing multipart stream
  // cannot advance to the next part. Ignored parts carry an empty source and
  // therefore yield 0.
  // NOTE(review): this blocks inside the stream's map stage for up to maxDuration;
  // an asynchronous stage would avoid that — confirm what this Akka Streams version offers.
  def countBytes(blob: BlobPart): Long =
    Await.result(blob.source.runFold(0L) { (sum, bytes) ⇒
      val newSum = sum + bytes.size
      if (debug) println(s"Total $newSum after reading ${bytes.size}")
      newSum
    }, maxDuration)

  val blobParts = bodyParts.collect {
    case part @ FormData.BodyPart(name, entity, _, _) ⇒
      part.filename match {
        case Some(filename) if name == blobPartName ⇒
          BlobPart(filename, entity.dataBytes)
        case _ ⇒
          // ignore form parts that we don't recognize
          // (potential DOS since it will read all data)
          entity.getDataBytes.runWith(Sink.ignore())
          BlobPart("", Source.empty())
      }
  }
  blobParts.map(countBytes).map(_.toString + " ").runFold("")(_ + _)
}
/**
 * Creates a temporary file of exactly `size` pseudo-random bytes and returns
 * its canonical path.
 *
 * The random generator is seeded with `size`, so the file name prefix and the
 * file content are the same for every call with the same size. The caller is
 * responsible for deleting the file.
 *
 * @param size number of bytes to write; must not be negative
 * @return the canonical path of the created file
 * @throws IllegalArgumentException if `size` is negative
 */
def createBlobOfSize(size: Long): String = {
  // Reject invalid input before any file is created, so a negative size leaves no stray temp file behind.
  if (size < 0) throw new IllegalArgumentException(s"Cant write less than 0 bytes [$size]")

  val bytesSize: Int = 1024
  val bytes = Array.ofDim[Byte](bytesSize)

  // Writes `remaining` bytes in chunks of at most `bytesSize`, refilling the buffer for every chunk.
  @tailrec
  def writeRandomBytes(random: Random, remaining: Long, output: OutputStream): Unit =
    if (remaining > 0) {
      random.nextBytes(bytes)
      val writeSize = if (remaining > bytesSize) bytesSize else remaining.toInt
      output.write(bytes, 0, writeSize)
      writeRandomBytes(random, remaining - writeSize, output)
    }

  val rnd = new Random(size)
  // The first random value names the file; the following ones become its content.
  val name = rnd.nextLong.toHexString
  val temp = File.createTempFile(name, ".tmp")
  val out = new BufferedOutputStream(new FileOutputStream(temp))
  try {
    writeRandomBytes(rnd, size, out)
  } finally {
    out.close()
  }
  temp.getCanonicalPath
}
// Configuration for the test actor system: registers the TestKit event listener
// and sets the log level to DEBUG.
// NOTE(review): later Akka versions use `akka.loggers` in place of
// `akka.event-handlers` — confirm which key the Akka version in use honors.
val testConfig: Config = ConfigFactory.parseString("""
akka.event-handlers = ["akka.testkit.TestEventListener"]
akka.loglevel = DEBUG""")
// Upper bound used both for reading a single uploaded part and for the whole upload test.
val maxDuration: FiniteDuration = 10.seconds
}
/**
 * Spec that starts a temporary HTTP server, uploads a generated blob file to
 * it as a multipart form using the `curl` command line tool, and expects
 * `curl` to exit with status 0 within `maxDuration`.
 */
class FormUpload extends WordSpecLike with Matchers with BeforeAndAfterAll with Directives with Timeouts {
  import FormUpload._

  implicit val system: ActorSystem = ActorSystem(getClass.getSimpleName, testConfig)

  // Path of the generated blob file; assigned in beforeAll and deleted in afterAll.
  var blobPath: String = ""

  override def beforeAll(): Unit = {
    blobPath = createBlobOfSize((1024 * 1024 * 2.3).toLong)
  }

  final override def afterAll(): Unit = {
    // Shut the actor system down even if removing the blob file throws.
    try {
      new File(blobPath).delete()
    } finally {
      TestKit.shutdownActorSystem(system)
    }
  }

  import system.dispatcher
  implicit val materializer: ActorFlowMaterializer = ActorFlowMaterializer()

  "An HTTP Server" should {
    "Accept a simple large multipart form upload" in {
      withHttpServer(
        // format: OFF
        path("form") {
          post {
            extractRequest { request ⇒
              complete(loadSimpleForm(request.entity))
            }
          }
        // format: ON
        }) { (address, port) ⇒
        failAfter(Span(maxDuration.toMillis, Milliseconds)) {
          uploadSimpleForm(address, port) should be(0)
        }
      }
    }
  }

  /**
   * Posts the blob file twice in one multipart form, once under a field name
   * the server ignores and once under `blobPartName`, by running `curl`.
   *
   * @return the exit code of the `curl` process
   */
  private def uploadSimpleForm(address: String, port: Int): Int = {
    val command = List(
      "curl",
      "-v", // verbose
      "-H", "Expect:", // remove Expect: 100-continue
      "--form", s"""ignoredFormField=@"$blobPath"""",
      "--form", s"""$blobPartName=@"$blobPath"""",
      s"http://$address:$port/form")
    if (debug) println(s"About to execute: ${command.mkString(" ")}")
    command.!
  }

  /**
   * Unmarshals the request entity as multipart form data and reads its parts
   * with `readBlob`.
   *
   * @return the response text: "Read the blob " followed by the result of `readBlob`
   */
  private def loadSimpleForm(entity: HttpEntity): Future[String] =
    for {
      multiPartFormData ← Unmarshal(entity).to[FormData]
      blobResult ← readBlob(multiPartFormData.parts)
    } yield "Read the blob " + blobResult
}
@turb
Copy link

turb commented Mar 26, 2015

Hello Björn, thanks for this piece of code.

I got it working; however, I quickly ran into:

Illegal request, responding with status '413 Request Entity Too Large': Request Content-Length 637876264 exceeds the configured limit of 104857600

So I changed akka.http.server.parsing.max-content-length to have a greater value.

However, if I increase the size of the blob I am sending (with httpie), I quickly run into:

http: error: MemoryError:

It seems the Http component 1) parses the whole request (=> OOM) then 2) streams it down.

Am I missing something? Is there any way to stream end-to-end through the whole chain?

Thanks again,

Sylvain

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment