Skip to content

Instantly share code, notes, and snippets.

@ddelange
Last active May 15, 2023 08:27
Show Gist options
  • Save ddelange/13b0f9da3147f3754b9e1e88c13303ba to your computer and use it in GitHub Desktop.
Save ddelange/13b0f9da3147f3754b9e1e88c13303ba to your computer and use it in GitHub Desktop.
Multithreaded S3 downloads
# pip install smart_open[s3]
from collections import deque
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
from functools import partial
from typing import Callable, Dict, Optional, Iterable, Iterator, Sequence
import boto3
import botocore
import smart_open
class URIDownloader:
    """Stream S3 URIs into memory using multithreading.

    One boto3 S3 client and one thread pool are created per instance and
    shared by every download. The instance can be used as a context manager
    (or ``close()`` can be called explicitly) to shut the thread pool down
    once downloading is finished.
    """

    def __init__(self, threads: int = 100):
        """Create the shared thread pool and S3 client.

        Args:
            threads: Used both as the number of worker threads and as the
                size of the client's HTTP connection pool.
        """
        self.threads = threads
        self.executor = ThreadPoolExecutor(max_workers=threads)
        config = botocore.client.Config(
            max_pool_connections=threads,
            tcp_keepalive=True,
            retries={"max_attempts": 6, "mode": "adaptive"},
        )
        client = boto3.session.Session().client("s3", config=config)  # thread-safe
        # Every open() call reuses the same client through transport_params.
        self._open = partial(smart_open.open, transport_params={"client": client})

    def __enter__(self) -> "URIDownloader":
        """Return the downloader itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Shut down the thread pool when leaving the ``with`` block."""
        self.close()

    def close(self, wait: bool = True) -> None:
        """Shut down the thread pool.

        Args:
            wait: Forwarded to the executor's ``shutdown``; when true, block
                until the worker threads have finished.
        """
        self.executor.shutdown(wait=wait)

    def read(self, uri: str, /, mode="rb", **kwargs) -> bytes:
        """Download (and decompress) a URI using smart_open.

        Args:
            uri: Location to open (positional-only).
            mode: File mode passed to ``smart_open.open``.
            **kwargs: Extra keyword arguments passed to ``smart_open.open``.

        Returns:
            The full content returned by ``read()`` on the opened file.
        """
        with self._open(uri, mode, **kwargs) as fp:
            return fp.read()

    def read_multi(self, uris: Iterable[str], **kwargs) -> Iterator[bytes]:
        """Download (and decompress) URIs with a multithreaded boto3 client.

        Results are yielded in the same order as ``uris``; ``kwargs`` are
        forwarded to ``read`` for every URI.
        """
        yield from self.executor.imap(partial(self.read, **kwargs), uris)

    def read_multi_dict(self, uris: Sequence[str], **kwargs) -> Dict[str, bytes]:
        """Download (and decompress) URIs with a multithreaded boto3 client into a dict[uri, bytes]."""
        # uris is iterated twice (once for the keys, once for the downloads),
        # hence a Sequence rather than a one-shot iterable.
        return dict(zip(uris, self.read_multi(uris, **kwargs)))
class ThreadPoolExecutor(_ThreadPoolExecutor):
    """The missing ThreadPoolExecutor.imap ref https://gist.github.com/ddelange/c98b05437f80e4b16bf4fc20fde9c999."""

    def imap(self, fn: Callable, *iterables: Iterable, timeout: Optional[float] = None):
        """Ordered imap that lazily consumes iterables.

        Tasks are submitted as the generator is iterated, keeping at most
        ``3 * max_workers`` futures outstanding, and results are yielded in
        submission order.

        Args:
            fn: Callable invoked with one item from each iterable.
            *iterables: Zipped together, so iteration stops at the shortest.
            timeout: Seconds to wait for each individual result; ``None``
                waits indefinitely.

        Yields:
            The result of ``fn(*args)`` for each tuple of arguments, in order.

        Raises:
            Whatever ``fn`` raises for a task, or the timeout error raised by
            ``Future.result`` when a result is not ready in time. In both
            cases, and when the generator is closed early, futures that have
            not started yet are cancelled.
        """
        futures = deque()
        maxlen = self._max_workers * 3  # one running plus two queued tasks per worker
        append, popleft, submit = futures.append, futures.popleft, self.submit

        def get():
            """Block until the next task is done and return the result."""
            # Peek before popping: on failure the future stays in the deque
            # and is cancelled by the cleanup below.
            result = futures[0].result(timeout)
            popleft()
            return result

        try:
            for args in zip(*iterables):
                append(submit(fn, *args))
                if len(futures) == maxlen:
                    yield get()
            while futures:
                yield get()
        finally:
            # Empty after a normal run; otherwise stop work nobody will
            # collect (cancel() is a no-op on running or finished futures).
            for future in futures:
                future.cancel()
@ddelange
Copy link
Author

ddelange commented Feb 3, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment