Skip to content

Instantly share code, notes, and snippets.

@zrzka
Created August 21, 2019 13:39
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save zrzka/8d7a31d39206f31e9cbafc6456546b60 to your computer and use it in GitHub Desktop.
How to save an actix multipart field to s3 using rusoto_s3?
use std::convert::TryFrom;
use actix_multipart::Multipart;
use actix_web::{error, middleware, web, App, Error, HttpResponse, HttpServer};
use futures::{Future, IntoFuture, Stream};
use s3::S3MultipartUpload;
mod s3;
fn upload(multipart: Multipart) -> impl Future<Item = HttpResponse, Error = Error> {
multipart
.map_err(error::ErrorInternalServerError)
.and_then(|f| S3MultipartUpload::try_from(f))
.and_then(|f| f.into_future())
.collect()
.map(|result| HttpResponse::Ok().json(result))
.map_err(|e| {
println!("failed: {}", e);
e
})
}
/// GET handler: serves a minimal HTML page with a file-upload form.
fn index() -> HttpResponse {
    // Fixed: the form previously used `target="/"`, which names a browsing
    // context, not a URL — `action="/"` is required for the POST to reach
    // this server's upload route. Also removed a stray `</button>` closing
    // tag that followed the void `<input type="submit">` element.
    let html = r#"<html>
<head><title>Upload Test</title></head>
<body>
<form action="/" method="post" enctype="multipart/form-data">
<input type="file" name="file"/>
<input type="submit" value="Submit"/>
</form>
</body>
</html>"#;
    HttpResponse::Ok().body(html)
}
/// Entry point: configures logging and serves the upload form (GET) and
/// the async upload handler (POST) on 127.0.0.1:8080.
fn main() -> std::io::Result<()> {
    std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info");
    env_logger::init();

    let app_factory = || {
        let root = web::resource("/")
            .route(web::get().to(index))
            .route(web::post().to_async(upload));
        App::new().wrap(middleware::Logger::default()).service(root)
    };

    HttpServer::new(app_factory).bind("127.0.0.1:8080")?.run()
}
use std::convert::TryFrom;
use actix_multipart::Field;
use actix_web::{error, Error};
use bytes::{Bytes, BytesMut};
use futures::stream;
use futures::{Async, Future, Stream};
use rusoto_core::{ByteStream, Region};
use rusoto_s3::{
AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CompletedMultipartUpload,
CompletedPart, CreateMultipartUploadRequest, S3Client, UploadPartRequest, S3,
};
// S3 multipart minimum part size (5 MiB): every part except the last must
// be at least this large, so incoming field chunks are buffered up to it.
const S3_MIN_PART_SIZE: usize = 5_242_880;
// TODO: Make the target bucket/region configurable; hard-coded for now.
const S3_BUCKET: &str = "images.zrzka.dev";
const S3_REGION: Region = Region::EuCentral1;
// Semantic aliases for the plain strings rusoto hands back; not type-safe,
// they just make the future signatures below easier to read.
type UploadId = String;
type ETag = String;
type Location = String;
// Helpers to create task futures
fn create_multipart_upload(
client: &impl S3,
bucket: &str,
key: &str,
) -> impl Future<Item = UploadId, Error = Error> {
let request = CreateMultipartUploadRequest {
bucket: bucket.into(),
key: key.into(),
..Default::default()
};
client
.create_multipart_upload(request)
.map_err(|e| error::ErrorInternalServerError(e))
.and_then(|mut output| match output.upload_id.take() {
Some(upload_id) => Ok(upload_id),
None => Err(error::ErrorInternalServerError("Unable to get upload_id")),
})
}
fn upload_part(
client: &impl S3,
bucket: &str,
key: &str,
upload_id: &str,
part_number: i64,
body: Bytes,
) -> impl Future<Item = ETag, Error = Error> {
let request = UploadPartRequest {
bucket: bucket.into(),
key: key.into(),
upload_id: upload_id.into(),
part_number,
content_length: Some(body.len() as i64),
body: Some(ByteStream::new(stream::once(Ok(body)))),
..Default::default()
};
client
.upload_part(request)
.map_err(|e| error::ErrorInternalServerError(e))
.and_then(|mut output| match output.e_tag.take() {
Some(e_tag) => Ok(e_tag),
None => Err(error::ErrorInternalServerError("Unable to get etag")),
})
}
fn abort_multipart_upload(
client: &impl S3,
bucket: &str,
key: &str,
upload_id: &str,
) -> impl Future<Item = (), Error = Error> {
let request = AbortMultipartUploadRequest {
bucket: bucket.into(),
key: key.into(),
upload_id: upload_id.into(),
..Default::default()
};
client
.abort_multipart_upload(request)
.map(|_| ())
.map_err(|e| error::ErrorInternalServerError(e))
}
fn complete_multipart_upload(
client: &impl S3,
bucket: &str,
key: &str,
upload_id: &str,
e_tags: &[ETag],
) -> impl Future<Item = Location, Error = Error> {
let completed_parts = e_tags
.iter()
.enumerate()
.map(|(idx, e_tag)| CompletedPart {
e_tag: Some(e_tag.clone()),
part_number: Some((idx + 1) as i64),
})
.collect();
let multipart_upload = CompletedMultipartUpload {
parts: Some(completed_parts),
};
let request = CompleteMultipartUploadRequest {
bucket: bucket.into(),
key: key.into(),
upload_id: upload_id.into(),
multipart_upload: Some(multipart_upload),
..Default::default()
};
client
.complete_multipart_upload(request)
.map_err(|e| error::ErrorInternalServerError(e))
.and_then(|mut output| match output.location.take() {
Some(location) => Ok(location),
None => Err(error::ErrorInternalServerError(
"Upload completed, but unable to get location",
)),
})
}
// Current state of the upload state machine; each in-flight variant owns
// the boxed future being polled for that step.
enum Task {
    // Waiting for S3 to hand back a multipart upload id.
    GetUploadId(Box<dyn Future<Item = UploadId, Error = Error>>),
    // Waiting for the buffered field stream to yield the next part.
    GetNextPart,
    // Waiting for a single part upload to finish (resolves to its ETag).
    UploadPart(Box<dyn Future<Item = ETag, Error = Error>>),
    // Waiting for S3 to assemble the parts (resolves to object location).
    CompleteUpload(Box<dyn Future<Item = Location, Error = Error>>),
    // Waiting for S3 to discard already-uploaded parts after a failure.
    AbortUpload(Box<dyn Future<Item = (), Error = Error>>),
}
// Future that drives one complete S3 multipart upload for a single
// multipart form field: create upload -> upload parts -> complete,
// diverting to abort on any error.
pub struct S3MultipartUpload {
    bucket: String,
    // S3 object key, derived from the field's filename
    key: String,
    client: S3Client,
    // Source field re-chunked into S3-sized (>= 5 MiB) parts
    field: FieldS3Parts,
    // Currently running sub-task; None only before the first poll
    task: Option<Task>,
    // Multipart upload identifier
    upload_id: Option<String>,
    // Next multipart upload part number
    next_part_number: i64,
    // List of etags of already uploaded parts
    e_tags: Vec<ETag>,
    // Error to report after multipart upload abort
    abort_error: Option<Error>,
}
impl TryFrom<Field> for S3MultipartUpload {
    type Error = Error;

    /// Builds an upload future from a multipart field.
    ///
    /// The S3 object key is the field's `Content-Disposition` filename with
    /// spaces replaced by underscores. Fails if the header or the filename
    /// is missing.
    fn try_from(value: Field) -> Result<Self, Self::Error> {
        let content_disposition = value
            .content_disposition()
            .ok_or_else(|| error::ErrorInternalServerError("Missing Content-Disposition"))?;
        let key = content_disposition
            .get_filename()
            // `replace` already returns an owned String; the original code
            // tacked on a redundant `.to_string()` (an extra allocation).
            .map(|f| f.replace(' ', "_"))
            .ok_or_else(|| error::ErrorInternalServerError("Missing filename"))?;
        Ok(S3MultipartUpload {
            bucket: S3_BUCKET.to_string(),
            client: S3Client::new(S3_REGION),
            key,
            next_part_number: 1,
            e_tags: vec![],
            upload_id: None,
            field: FieldS3Parts::new(value),
            task: None,
            abort_error: None,
        })
    }
}
impl Future for S3MultipartUpload {
type Item = Location;
type Error = Error;
fn poll(&mut self) -> Result<Async<Self::Item>, Self::Error> {
if self.task.is_none() {
// As a first thing we have to get upload id
let request = create_multipart_upload(&self.client, &self.bucket, &self.key);
self.task = Some(Task::GetUploadId(Box::new(request)));
}
match self.task.take().unwrap() {
Task::GetUploadId(mut request) => match request.poll() {
Ok(Async::Ready(upload_id)) => {
println!("GetUploadId: Async::Ready: {}", upload_id);
self.upload_id = Some(upload_id);
self.task = Some(Task::GetNextPart);
}
Ok(Async::NotReady) => {
println!("GetUploadId: Async::NotReady");
self.task = Some(Task::GetUploadId(request));
return Ok(Async::NotReady);
}
Err(e) => {
println!("GetUploadId: Err: {}", e);
return Err(e);
}
},
Task::GetNextPart => match self.field.poll() {
Ok(Async::Ready(Some(bytes))) => {
println!("GetNextPart: Async::Ready: {}", bytes.len());
let request = upload_part(
&self.client,
&self.bucket,
&self.key,
&self.upload_id.as_ref().unwrap(),
self.next_part_number,
bytes,
);
self.task = Some(Task::UploadPart(Box::new(request)));
}
Ok(Async::Ready(None)) => {
println!("GetNextPart: Async::Ready: no more parts");
let request = complete_multipart_upload(
&self.client,
&self.bucket,
&self.key,
&self.upload_id.as_ref().unwrap(),
&self.e_tags,
);
self.task = Some(Task::CompleteUpload(Box::new(request)));
}
Ok(Async::NotReady) => {
println!("GetNextPart: Async::NotReady");
self.task = Some(Task::GetNextPart);
return Ok(Async::NotReady);
}
Err(e) => {
println!("GetNextPart: Err: {}", e);
self.abort_error = Some(error::ErrorInternalServerError(e));
let request = abort_multipart_upload(
&self.client,
&self.bucket,
&self.key,
&self.upload_id.as_ref().unwrap(),
);
self.task = Some(Task::AbortUpload(Box::new(request)));
}
},
Task::UploadPart(mut request) => match request.poll() {
Ok(Async::Ready(e_tag)) => {
println!("UploadPart: Async::Ready: {}", e_tag);
self.e_tags.push(e_tag);
self.next_part_number += 1;
self.task = Some(Task::GetNextPart);
}
Ok(Async::NotReady) => {
println!("UploadPart: Async::NotReady");
self.task = Some(Task::UploadPart(request));
return Ok(Async::NotReady);
}
Err(e) => {
println!("UploadPart: Err: {}", e);
self.abort_error = Some(error::ErrorInternalServerError(e));
let request = abort_multipart_upload(
&self.client,
&self.bucket,
&self.key,
&self.upload_id.as_ref().unwrap(),
);
self.task = Some(Task::AbortUpload(Box::new(request)));
}
},
Task::CompleteUpload(mut request) => match request.poll() {
Ok(Async::Ready(location)) => {
println!("CompleteUpload: Async::Ready: {}", location);
return Ok(Async::Ready(location));
}
Ok(Async::NotReady) => {
println!("CompleteUpload: Async::NotReady");
self.task = Some(Task::CompleteUpload(Box::new(request)));
return Ok(Async::NotReady);
}
Err(e) => {
println!("CompleteUpload: Err: {}", e);
self.abort_error = Some(error::ErrorInternalServerError(e));
let request = abort_multipart_upload(
&self.client,
&self.bucket,
&self.key,
&self.upload_id.as_ref().unwrap(),
);
self.task = Some(Task::AbortUpload(Box::new(request)));
}
},
Task::AbortUpload(mut request) => match request.poll() {
Ok(Async::Ready(_)) => {
println!("AbortUpload: Async::Ready");
return Err(self.abort_error.take().unwrap());
}
Ok(Async::NotReady) => {
println!("AbortUpload: Async::NotReady");
self.task = Some(Task::AbortUpload(request));
return Ok(Async::NotReady);
}
Err(e) => {
println!("AbortUpload: Err: {}", e);
return Err(self.abort_error.take().unwrap());
}
},
};
self.poll()
}
}
// Adapts an actix multipart `Field` stream into a stream of `Bytes` chunks
// that are each at least S3_MIN_PART_SIZE (except possibly the last),
// because S3 rejects multipart parts smaller than the minimum part size.
struct FieldS3Parts {
    // Source field; set to None once it is exhausted
    field: Option<Field>,
    // Accumulates incoming chunks until a full S3 part is available
    buffer: BytesMut,
}
impl FieldS3Parts {
    // Wraps `field`, pre-allocating one part's worth of buffer capacity.
    fn new(field: Field) -> FieldS3Parts {
        FieldS3Parts {
            field: Some(field),
            buffer: BytesMut::with_capacity(S3_MIN_PART_SIZE),
        }
    }
    // Drains the whole buffer into an immutable `Bytes` part, then reserves
    // the same amount of capacity so the next part accumulates without
    // reallocating.
    fn bytes(&mut self) -> Bytes {
        let bytes = self.buffer.take().freeze();
        self.buffer.reserve(bytes.len());
        bytes
    }
}
impl Stream for FieldS3Parts {
type Item = Bytes;
type Error = Error;
fn poll(&mut self) -> Result<Async<Option<Self::Item>>, Self::Error> {
if self.field.is_none() {
return Ok(Async::Ready(None));
}
match self.field.as_mut().unwrap().poll() {
Ok(Async::Ready(Some(bytes))) => {
println!("FieldS3Parts: Async::Ready: {}", bytes.len());
self.buffer.extend(bytes);
if self.buffer.len() >= S3_MIN_PART_SIZE {
println!("FieldS3Parts: buffer filled");
Ok(Async::Ready(Some(self.bytes())))
} else {
println!("FieldS3Parts: buffer too small");
Ok(Async::NotReady)
}
}
Ok(Async::Ready(None)) => {
println!("FieldS3Parts: Async::Ready: no more parts");
self.field = None;
if self.buffer.len() > 0 {
println!("FieldS3Parts: sending remaining buffer");
Ok(Async::Ready(Some(self.bytes())))
} else {
Ok(Async::Ready(None))
}
}
Ok(Async::NotReady) => {
println!("FieldS3Parts: Async::NotReady");
Ok(Async::NotReady)
}
Err(e) => {
println!("FieldS3Parts: Err: {}", e);
Err(error::ErrorInternalServerError(e))
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment