Skip to content

Instantly share code, notes, and snippets.

@a-canela
Last active May 15, 2023 00:34
Show Gist options
  • Save a-canela/4cbbe20b08ce1fa92ff373d5b60ac9ef to your computer and use it in GitHub Desktop.
Save a-canela/4cbbe20b08ce1fa92ff373d5b60ac9ef to your computer and use it in GitHub Desktop.
Directory download (sync) through boto3, supporting an include-only filter pattern
import os
from concurrent.futures import ThreadPoolExecutor
from fnmatch import fnmatch
from multiprocessing import Pipe, Process, Queue
from traceback import format_exc
from typing import Optional

import boto3
from botocore.config import Config
class DownloadException(Exception):
    """Raised when a download operation fails."""
class S3Exception(Exception):
    """Raised for errors originating from AWS S3 operations."""
class S3DownloadException(S3Exception, DownloadException):
    """Raised when an S3 download fails; catchable as either base."""
class ProcessWithExceptionPiping(Process):
    """Multiprocessing Process with exception piping.

    The child end of a one-way pipe carries either ``None`` (clean exit)
    or an ``(exception, formatted_traceback)`` tuple; the parent inspects
    it on ``join`` and can re-raise the child's failure.
    """

    def __init__(
        self,
        *args,
        raise_at_child: bool = False,
        raise_at_parent: bool = True,
        **kwargs,
    ):
        """Initialize process with exception piping.

        Args:
            raise_at_child (bool, optional): Raise exception at child process.
                Defaults to False.
            raise_at_parent (bool, optional): Raise exception at parent
                process. Defaults to True.
        """
        self.raise_at_child = raise_at_child
        self.raise_at_parent = raise_at_parent
        # Last (exception, traceback) received from the child; None until then.
        self.exception_traceback = None
        self.parent_connection, self.child_connection = Pipe()
        super().__init__(*args, **kwargs)

    def run(self):
        # Runs in the child: report the outcome through the pipe either way.
        try:
            super().run()
        except BaseException as exception:
            self.child_connection.send((exception, format_exc()))
            if self.raise_at_child:
                raise exception
        else:
            self.child_connection.send(None)

    def join(self, *args, **kwargs):
        # Wait for the child, then surface its failure here if configured to.
        super().join(*args, **kwargs)
        if not self.raise_at_parent:
            return
        exception_traceback = self.get_exception_traceback()
        if exception_traceback is None:
            return
        exception, traceback = exception_traceback
        # Re-raise the same exception type, embedding the child's traceback
        # text in the message so the real origin is visible in the parent.
        raise type(exception)(f'{traceback}\n{exception}')

    def get_exception_traceback(self):
        """Return the child's (exception, traceback) tuple, or None."""
        if self.parent_connection.poll():
            self.exception_traceback = self.parent_connection.recv()
        return self.exception_traceback
class S3Client:
    """AWS S3 boto3-based client for downloading/syncing folders."""

    def __init__(
        self,
        bucket: str = '',
        max_attempts: int = 10,
        max_pool_connections: int = 100,
        max_download_workers: int = 20,
    ) -> None:
        """AWS S3 boto3-based client constructor.

        Note: Implementation only valid for UNIX.

        Args:
            bucket (str): Bucket path; a leading 's3://' scheme is stripped.
            max_attempts (int, optional): Maximum retry attempts.
                Defaults to 10.
            max_pool_connections (int, optional): Maximum number of concurrent
                requests to AWS S3. Defaults to 100.
            max_download_workers (int, optional): Maximum number of workers for
                downloading files. Defaults to 20.

        Attributes:
            success_paths (list): List of tuples of successfully uploaded paths.
        """
        self.success_paths = []
        # Strip only a *leading* 's3://'; str.replace would also mangle the
        # substring if it appeared anywhere else in the string.
        self.bucket = bucket[len('s3://'):] if bucket.startswith('s3://') else bucket
        self.max_attempts = max_attempts
        config = Config(
            retries={
                'mode': 'standard',
                'max_attempts': max_attempts,
            },
            max_pool_connections=max_pool_connections,
        )
        self.client = boto3.client('s3', config=config)
        self.max_download_workers = max_download_workers
        self.paginator = self.client.get_paginator('list_objects_v2')

    def download_folder(
        self,
        dst_dir: str,
        prefix: str,
        pattern: Optional[str] = None,
        bucket: str = '',
    ) -> None:
        """Download a folder (recursively) from S3.

        Listing/enqueueing runs in this process while a background process
        consumes the queue and performs the downloads via a thread pool.

        TODO: Benchmark this (Process + ThreadPoolExecutor) vs only
        ThreadPoolExecutor. Test cases: small/large list (> 10000 files),
        need/no need to re-download.

        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. For more
                information, check fnmatch library. Examples:
                '*/my_folder/image_???_*.png', '*.json'. Defaults to None.
            bucket (str, optional): File/folder S3 bucket.
                Defaults to self.bucket.
        """
        self.downloader_queue = Queue()
        # NOTE(review): the child process reuses self.client created in the
        # parent; boto3 clients are not documented fork-safe — confirm.
        self.downloader = ProcessWithExceptionPiping(target=self._downloader)
        self.downloader.start()
        try:
            self._enqueue_downloads(
                dst_dir=dst_dir.rstrip('/'),
                prefix=prefix.rstrip('/'),
                pattern=pattern,
                bucket=bucket or self.bucket,
            )
        finally:
            # Sentinel stops the downloader's consume loop; join() then waits
            # for pending downloads and re-raises any child exception.
            self.downloader_queue.put(None)
            self.downloader.join()

    def _enqueue_downloads(
        self,
        dst_dir: str,
        prefix: str,
        pattern: Optional[str] = None,
        bucket: str = '',
    ) -> None:
        """Enqueue files to be downloaded by the downloader background process.

        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. Defaults to None.
            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
        """
        root_dir = f'{dst_dir}/{prefix}'
        prefix_dir = f'{prefix}/'
        for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix):
            file_data = []
            for file in page.get('Contents', ()):
                key = file.get('Key')
                # Strip only the *leading* prefix. The previous
                # key.replace(f'{prefix}/', '') removed every occurrence of
                # the prefix substring (e.g. prefix='data',
                # key='data/data/x' -> 'x' instead of 'data/x').
                sub_path = key[len(prefix_dir):] if key.startswith(prefix_dir) else key
                if pattern and not fnmatch(sub_path, pattern):
                    continue
                file_data.append(
                    (file.get('LastModified').timestamp(), file.get('Size'), key, sub_path)
                )
            # Pre-create every destination sub-directory once per page.
            for sub_dir in {os.path.dirname(item[3]) for item in file_data}:
                os.makedirs(f'{root_dir}/{sub_dir}', exist_ok=True)
            for timestamp, size, key, sub_path in file_data:
                self.downloader_queue.put(
                    (bucket, key, f'{root_dir}/{sub_path}', timestamp, size)
                )

    def _downloader(self) -> None:
        """Downloader background process.

        Consumes the queue until the None sentinel, fanning each download out
        to a thread pool; resolving every future surfaces the first error.
        """
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_download_workers) as executor:
            args = self.downloader_queue.get()
            while args is not None:
                futures.append(executor.submit(self._download_file, *args))
                args = self.downloader_queue.get()
            for future in futures:
                future.result()

    def _download_file(
        self,
        bucket: str,
        key: str,
        dst_path: str,
        timestamp: float,
        size: int,
    ) -> None:
        """Download a file and assign the provided timestamp to it.

        Skips the download when the local file already matches both the S3
        mtime and size (i.e. it is already in sync).

        Args:
            bucket (str): File S3 bucket.
            key (str): File S3 key.
            dst_path (str): File destination local path.
            timestamp (float): File S3 timestamp.
            size (int): File S3 size.

        Raises:
            S3DownloadException: If the file is missing after the download.
        """
        if os.path.exists(dst_path):
            dst_file_stat = os.stat(dst_path)
            # Exact float comparison is intentional: we set the mtime
            # ourselves below; a coarser-resolution filesystem only costs a
            # redundant re-download.
            if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size:
                return
        self.client.download_file(bucket, key, dst_path)
        if os.path.exists(dst_path):
            # Mirror the S3 LastModified so the sync check above can skip
            # this file next time.
            os.utime(dst_path, (timestamp, timestamp))
        else:
            raise S3DownloadException(f'Error downloading: s3://{bucket}/{key} => {dst_path}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment