Skip to content

Instantly share code, notes, and snippets.

@jdunck
Forked from BradWhittington/s3_multipart.py
Last active May 15, 2019 13:55
Show Gist options
  • Save jdunck/4393349 to your computer and use it in GitHub Desktop.
Save jdunck/4393349 to your computer and use it in GitHub Desktop.
upload to s3 with streaming, multi-part, threaded upload, with key rollover as you pass the size limit (3,000,000,000 bytes — about 2.8 GiB — by default; see max_size), with adjustable buffering and pooling. Don't forget to call uploader.close() when you're done.
from multiprocessing.pool import ThreadPool
import logging, os, threading
from StringIO import StringIO
import boto.s3
logger = logging.getLogger('s3upload')
class MultiPartUploader(object):
    """Stream data to S3 as a multi-part upload, rolling over to a new key
    once roughly ``max_size`` bytes have been written per key.

    Buffered writes are handed to a background thread pool as S3 parts.
    Call :meth:`close` when finished to flush the final part and complete
    the upload.
    """

    # Sequence number of the most recently flushed part for the current key
    # (reset to 0 on every key rollover).
    upload_part = 0
    # Per-key byte limit before rolling over to the next key (~2.8 GiB).
    max_size = 3000000000

    def __init__(self, access, secret, bucket_name, key_format, buf_size=None, upload_pool_size=10):
        """
        :param access: AWS access key id.
        :param secret: AWS secret access key.
        :param bucket_name: destination bucket name.
        :param key_format: ``%``-format string applied to the integer key
            sequence number, e.g. ``'backup.%03d'``.
        :param buf_size: bytes to buffer before uploading a part; clamped up
            to S3's 5 MiB minimum part size. ``None`` means the minimum.
        :param upload_pool_size: number of concurrent upload threads.
        """
        self.lock = threading.Lock()
        self.total_uploaded = 0
        self.connection = boto.connect_s3(access, secret)
        self.bucket = self.connection.get_bucket(bucket_name)
        self.key_format = key_format
        self.key_sequence = 1
        self.mp = None
        # max(None, n) worked on Python 2 but raises TypeError on Python 3;
        # handle the default explicitly. 5242880 (5 MiB) is S3's minimum
        # size for a non-final part.
        self.buf_size = 5242880 if buf_size is None else max(buf_size, 5242880)
        self.buffer = StringIO()
        self.upload_pool = ThreadPool(upload_pool_size)
        self.upload_results = []

    def init_mp(self):
        """Begin a multipart upload for the current key, completing any
        in-flight upload for the previous key first."""
        if self.mp is not None:
            self.close()
        with self.lock:
            self.mp = self.bucket.initiate_multipart_upload(self.key_name())
            self.upload_part = 0

    def key_name(self):
        """Return the S3 key name for the current sequence number."""
        return self.key_format % self.key_sequence

    def handle_key_rollover(self):
        """Lazily start the first upload, and advance to the next key once
        the average bytes written per key passes ``max_size``."""
        if self.mp is None:
            self.init_mp()
        if self.total_uploaded / self.key_sequence < self.max_size:
            return
        self.key_sequence += 1
        key_name = self.key_name()
        logger.info('rolling over to next key %s', key_name)
        self.init_mp()

    def write(self, s):
        """File-like write: buffer ``s`` and flush a part to the pool once
        the buffer exceeds ``buf_size``."""
        self.total_uploaded += len(s)
        self.handle_key_rollover()
        self.buffer.write(s)
        # tell() equals the buffered length because we only ever append;
        # the py2-only StringIO ``.len`` attribute does not exist on
        # Python 3's io.StringIO.
        if self.buffer.tell() > self.buf_size:
            self.flush()

    def async_upload(self, mp, part, buffer):
        """Upload one buffered part (runs on the thread pool); closes the
        buffer when done."""
        import random
        # Random id only correlates the start/done log lines of one upload.
        upload_id = int(random.random() * 10000000)
        upload_size = buffer.tell()
        buffer.seek(0)
        logger.info("%s: uploading %s as part %s", upload_id, upload_size, part)
        mp.upload_part_from_file(buffer, part)
        logger.info("%s: upload done", upload_id)
        buffer.close()

    def wait_for_completion(self, timeout_seconds):
        """Block until every queued part upload has finished.

        :param timeout_seconds: per-result wait passed to ``AsyncResult.get``.
        :returns: True if all pending uploads succeeded, False as soon as
            one fails or times out (remaining results are left queued).
        """
        while self.upload_results:
            logger.info('waiting for %s pending results', len(self.upload_results))
            next_result = self.upload_results.pop()
            try:
                next_result.get(timeout_seconds)
            except Exception:
                # A bare ``except:`` here would also swallow
                # KeyboardInterrupt/SystemExit; log the failure rather
                # than discarding it silently.
                logger.exception('part upload failed')
                return False
        return True

    def flush(self):
        """Hand the current buffer to the pool as the next part and start a
        fresh buffer. No-op when the buffer is empty."""
        if self.buffer.tell():
            self.upload_part += 1
            logger.info('flushing %s to part %s; total: %s',
                        self.buffer.tell(), self.upload_part, self.total_uploaded)
            self.upload_results.append(
                self.upload_pool.apply_async(self.async_upload,
                                             [self.mp, self.upload_part, self.buffer])
            )
            self.buffer = StringIO()

    def close(self, timeout_seconds=10):
        """Flush the final part, wait for pending uploads, and complete the
        multipart upload. Safe to call when nothing was ever written."""
        with self.lock:
            mp = self.mp
            self.flush()
            if not self.wait_for_completion(timeout_seconds):
                # Preserve the original best-effort behavior (the upload is
                # still completed) but make the failure visible.
                logger.warning('completing upload despite failed parts')
            if mp is not None:
                mp.complete_upload()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment