Last active
May 15, 2023 00:34
-
-
Save a-canela/4cbbe20b08ce1fa92ff373d5b60ac9ef to your computer and use it in GitHub Desktop.
Directory download (sync) via boto3, with support for an include-only filter pattern
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
from concurrent.futures import ThreadPoolExecutor
from fnmatch import fnmatch
from multiprocessing import Pipe, Process, Queue
from traceback import format_exc
from typing import Optional

import boto3
from botocore.config import Config
class DownloadException(Exception):
    """Raised when a download operation fails."""
class S3Exception(Exception):
    """Base exception for AWS S3 related errors."""
class S3DownloadException(S3Exception, DownloadException):
    """Raised when downloading an object from S3 fails."""
class ProcessWithExceptionPiping(Process):
    """Multiprocessing Process that forwards child exceptions to the parent via a Pipe.

    The child's exception (and its formatted traceback) is sent through a pipe so the
    parent can re-raise it from `join()` instead of silently losing it.
    """
    def __init__(
        self,
        *args,
        raise_at_child: bool = False,
        raise_at_parent: bool = True,
        **kwargs,
    ):
        """Initialize process with exception piping.
        Args:
            raise_at_child (bool, optional): Raise exception at child process. Defaults to False.
            raise_at_parent (bool, optional): Raise exception at parent process. Defaults to True.
        """
        self.raise_at_child = raise_at_child
        self.raise_at_parent = raise_at_parent
        # Cached (exception, traceback) tuple received from the child, or None.
        self.exception_traceback = None
        self.parent_connection, self.child_connection = Pipe()
        super().__init__(*args, **kwargs)
    def run(self):
        try:
            super().run()
            # Sentinel: the child finished without raising.
            self.child_connection.send(None)
        except BaseException as exception:
            # NOTE: send() requires the exception to be picklable; an unpicklable
            # exception would raise here in the child.
            traceback = format_exc()
            self.child_connection.send((exception, traceback))
            if self.raise_at_child:
                # Bare `raise` preserves the original traceback (re-raising the
                # bound name would restart the traceback at this frame).
                raise
    def join(self, *args, **kwargs):
        """Join the child; re-raise its exception in the parent if one was piped."""
        super().join(*args, **kwargs)
        if self.raise_at_parent:
            exception_traceback = self.get_exception_traceback()
            if exception_traceback:
                exception, traceback = exception_traceback
                # Some exception classes cannot be constructed from a single
                # string (e.g. certain OSError subclasses); fall back to a
                # RuntimeError instead of masking the child error with a
                # TypeError from the failed constructor call.
                try:
                    raised = type(exception)(f'{traceback}\n{exception}')
                except TypeError:
                    raised = RuntimeError(f'{traceback}\n{exception}')
                raise raised
    def get_exception_traceback(self):
        """Return the child's (exception, traceback) tuple, or None if it succeeded."""
        if self.parent_connection.poll():
            self.exception_traceback = self.parent_connection.recv()
        return self.exception_traceback
class S3Client:
    """AWS S3 boto3-based client for downloading/syncing folders"""
    def __init__(
        self,
        bucket: str = '',
        max_attempts: int = 10,
        max_pool_connections: int = 100,
        max_download_workers: int = 20,
    ) -> None:
        """AWS S3 boto3-based client constructor.
        Note: Implementation only valid for UNIX.
        Args:
            bucket (str): Bucket path (a leading 's3://' scheme is stripped).
            max_attempts (int, optional): Maximum retry attempts.
                Defaults to 10
            max_pool_connections (int, optional) Maximum number of concurrent requests to aws s3.
                Defaults to 100
            max_download_workers (int, optional) Maximum number of workers for downloading files.
                Defaults to 20
        Attributes:
            success_paths (list): List of tuples of successfully uploaded paths.
        """
        self.success_paths = []
        self.bucket = bucket.replace('s3://', '')
        self.max_attempts = max_attempts
        config = Config(
            retries={
                'mode': 'standard',
                'max_attempts': max_attempts,
            },
            max_pool_connections=max_pool_connections,
        )
        self.client = boto3.client('s3', config=config)
        self.max_download_workers = max_download_workers
        self.paginator = self.client.get_paginator('list_objects_v2')
    def download_folder(
        self,
        dst_dir: str,
        prefix: str,
        pattern: Optional[str] = None,
        bucket: str = '',
    ) -> None:
        """Download a folder (recursively) from S3.
        TODO: Benchmark this (Process + ThreadPoolExecutor) vs only ThreadPoolExecutor.
        Test cases: small/large list (> 10000 files), need/no need to re-download.
        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. For more information, check
                fnmatch library. Examples: '*/my_folder/image_???_*.png', '*.json'.
                Defaults to None.
            bucket (str, optional): File/folder S3 bucket.
                Defaults to self.bucket.
        """
        self.downloader_queue = Queue()
        self.downloader = ProcessWithExceptionPiping(target=self._downloader)
        self.downloader.start()
        try:
            dst_dir = dst_dir.rstrip('/')
            prefix = prefix.rstrip('/')
            bucket = bucket or self.bucket
            self._enqueue_downloads(
                dst_dir=dst_dir,
                prefix=prefix,
                pattern=pattern,
                bucket=bucket,
            )
        finally:
            # Always push the sentinel so the downloader process terminates,
            # even when enqueueing raised; join re-raises any child exception.
            self.downloader_queue.put(None)
            self.downloader.join()
    def _enqueue_downloads(
        self,
        dst_dir: str,
        prefix: str,
        pattern: Optional[str] = None,
        bucket: str = '',
    ) -> None:
        """Enqueue files to be downloaded by the downloader background process.
        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. Defaults to None.
            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
        """
        root_dir = f'{dst_dir}/{prefix}'
        prefix_dir = f'{prefix}/'
        for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix):
            files = page.get('Contents', ())
            file_data = []
            for file in files:
                key = file.get('Key')
                # Strip only the LEADING '<prefix>/' from the key. Using
                # str.replace here would remove every occurrence of the prefix
                # segment (e.g. prefix 'pre' and key 'pre/sub/pre/f.txt' would
                # yield 'sub/f.txt'), writing files to the wrong local path.
                sub_path = key[len(prefix_dir):] if key.startswith(prefix_dir) else key
                file_data.append((
                    file.get('LastModified').timestamp(),
                    file.get('Size'),
                    key,
                    sub_path,
                ))
            if pattern:
                # file[3] is the prefix-relative sub-path the pattern matches against.
                file_data = [file for file in file_data if fnmatch(file[3], pattern)]
            sub_dirs = {os.path.dirname(file[3]) for file in file_data}
            for sub_dir in sub_dirs:
                os.makedirs(f'{root_dir}/{sub_dir}', exist_ok=True)
            for timestamp, size, key, sub_path in file_data:
                self.downloader_queue.put((bucket, key, f'{root_dir}/{sub_path}', timestamp, size))
    def _downloader(self) -> None:
        """Downloader background process: drain the queue until the None sentinel
        arrives, fanning file downloads out to a thread pool."""
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_download_workers) as executor:
            while True:
                args = self.downloader_queue.get()
                if args is None:
                    # Sentinel pushed by download_folder: no more work.
                    break
                futures.append(executor.submit(self._download_file, *args))
            # Propagate the first download exception, if any, in this process.
            for future in futures:
                future.result()
    def _download_file(
        self,
        bucket: str,
        key: str,
        dst_path: str,
        timestamp: float,
        size: int,
    ) -> None:
        """Download a file and assign the provided timestamp to it.
        The download is skipped when the destination already exists with the same
        modification timestamp and size (i.e. it is already in sync).
        Args:
            bucket (str): File S3 bucket.
            key (str): File S3 key.
            dst_path (str): File destination local path.
            timestamp (float): File S3 timestamp (epoch seconds).
            size (int): File S3 size in bytes.
        Raises:
            S3DownloadException: If the file is missing after the download call.
        """
        if os.path.exists(dst_path):
            dst_file_stat = os.stat(dst_path)
            if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size:
                return
        self.client.download_file(bucket, key, dst_path)
        if os.path.exists(dst_path):
            # Mirror the S3 timestamp locally so a future sync can skip this file.
            os.utime(dst_path, (timestamp, timestamp))
        else:
            raise S3DownloadException(f'Error downloading: s3://{bucket}/{key} => {dst_path}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment