# Full Example: https://gist.github.com/danielhavir/407a6cfd592dfc2ad1e23a1ed3539e07
import os
from typing import Callable, List, Tuple, Generator, Dict
import torch
import torch.utils.data
from PIL.Image import Image as ImageType
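# `pil_loader` is referenced below but defined in the full example linked
# above; a minimal sketch, mirroring torchvision's standard PIL loader:
from PIL import Image

def pil_loader(path: str) -> ImageType:
    # Open the file explicitly so the handle is released once decoding
    # is forced by convert()
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")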
def list_items_local(path: str) -> List[str]:
    return sorted(os.path.splitext(f)[0] for f in os.listdir(path))
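# `consistent_train_test_split` is defined in the full example linked above.
# A minimal sketch of the idea, assuming the split is derived by hashing each
# item ID so an item always lands in the same split across runs and machines:
import hashlib

def consistent_train_test_split(items: List[str], test_ratio: float) -> Tuple[torch.Tensor, torch.Tensor]:
    train_indices, test_indices = [], []
    for i, item in enumerate(items):
        # md5 is stable across processes, unlike Python's salted built-in hash()
        bucket = int(hashlib.md5(item.encode()).hexdigest(), 16) % 10000 / 10000
        (test_indices if bucket < test_ratio else train_indices).append(i)
    return torch.tensor(train_indices, dtype=torch.long), torch.tensor(test_indices, dtype=torch.long)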
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_root: str, items: List[str], loader: Callable[[str], ImageType] = pil_loader, transform=None):
        self.data_root = data_root
        self.loader = loader
        self.items = items
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, item):
        # Image and label share an ID and differ only in directory and extension
        item_id = self.items[item]
        image = self.loader(os.path.join(self.data_root, "images", item_id + ".jpg"))
        label = self.loader(os.path.join(self.data_root, "labels", item_id + ".png"))
        if self.transform is not None:
            # The transform receives the (image, label) pair jointly so that
            # spatial augmentations stay aligned across both
            image, label = self.transform((image, label))
        return image, label
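# The transform is a single callable over the (image, label) pair. A minimal
# sketch of a compatible transform (assumes torchvision is available, which
# this gist does not itself import):
def joint_to_tensor(pair: Tuple[ImageType, ImageType]) -> Tuple[torch.Tensor, torch.Tensor]:
    import torchvision.transforms.functional as TF
    image, label = pair
    return TF.to_tensor(image), TF.to_tensor(label)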
def get_local_dataloaders(local_data_root: str, batch_size: int = 8, transform: Callable = None,
                          test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    # Local training
    local_items = list_items_local(os.path.join(local_data_root, "images"))
    dataset = ImageDataset(local_data_root, local_items, transform=transform)
    # Split using consistent hashing
    train_indices, test_indices = consistent_train_test_split(local_items, test_ratio)
    return {
        "train": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             sampler=torch.utils.data.SubsetRandomSampler(train_indices),
                                             num_workers=num_workers),
        "test": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                            sampler=torch.utils.data.SubsetRandomSampler(test_indices),
                                            num_workers=num_workers)
    }
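# Example usage with hypothetical paths; the transform must return tensors so
# the default collate function can batch the samples:
#   loaders = get_local_dataloaders("/data/segmentation", batch_size=4, transform=joint_to_tensor)
#   images, labels = next(iter(loaders["train"]))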
import io
from urllib.parse import urlparse
from PIL import Image
from google.cloud import storage
from google.api_core.retry import Retry
@Retry()
def gcs_pil_loader(uri: str) -> ImageType:
    uri = urlparse(uri)
    client = storage.Client()
    bucket = client.get_bucket(uri.netloc)
    b = bucket.blob(uri.path[1:], chunk_size=None)
    image = Image.open(io.BytesIO(b.download_as_string()))
    return image.convert("RGB")
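# For example, urlparse("gs://my-bucket/images/0001.jpg") yields
# netloc="my-bucket" and path="/images/0001.jpg", so uri.path[1:] strips the
# leading slash to form the blob name "images/0001.jpg".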
@Retry()
def load_items_gcs(path: str) -> List[str]:
    uri = urlparse(path)
    client = storage.Client()
    bucket = client.get_bucket(uri.netloc)
    blobs = bucket.list_blobs(prefix=uri.path[1:], delimiter=None)
    return sorted(os.path.splitext(os.path.basename(blob.name))[0] for blob in blobs)
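# Given a hypothetical layout gs://my-bucket/dataset/images/0001.jpg, ...,
# load_items_gcs("gs://my-bucket/dataset/images") returns ["0001", ...],
# the same IDs list_items_local would produce for a local copy.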
def get_streamed_dataloaders(gcs_data_root: str, batch_size: int = 8, transform: Callable = None,
                             test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    # Streaming
    streamed_items = load_items_gcs(os.path.join(gcs_data_root, "images"))
    dataset = ImageDataset(gcs_data_root, streamed_items, loader=gcs_pil_loader, transform=transform)
    # The split is identical for local and streamed data, which is handy for
    # cross-validation: split using consistent hashing
    train_indices, test_indices = consistent_train_test_split(streamed_items, test_ratio)
    return {
        "train": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             sampler=torch.utils.data.SubsetRandomSampler(train_indices),
                                             num_workers=num_workers),
        "test": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                            sampler=torch.utils.data.SubsetRandomSampler(test_indices),
                                            num_workers=num_workers)
    }
import random
import asyncio
import logging
from threading import Thread

import aiohttp
from janus import Queue
from gcloud.aio.storage import Storage
def generate_stream(items: List[str]) -> Generator[str, None, None]:
    while True:
        # Python's randint has inclusive upper bound
        index = random.randint(0, len(items) - 1)
        yield items[index]
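# The stream is infinite by design: the IterableDataset below pulls from it
# for as long as training runs, sampling uniformly with replacement instead
# of iterating in epochs.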
class AsyncImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, data_root: str, items: List[str], transform: Callable = None, concurrency: int = 64):
        self.data_root = data_root
        self.items = items
        self.transform = transform
        self.worker_initialized = False
        self.loop_thread = None
        self.q = None
        self.creds = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
        self.concurrency = concurrency
        self.stream = generate_stream(self.items)
    async def run(self, loop, session):
        # One Storage client per task, reused across all of its downloads
        aio_storage = Storage(service_file=self.creds, session=session)
        for item in self.stream:
            try:
                image_gs = urlparse(os.path.join(self.data_root, "images", item + ".jpg"))
                label_gs = urlparse(os.path.join(self.data_root, "labels", item + ".png"))
                # Fetch the image and its label concurrently
                blobs = await asyncio.gather(
                    aio_storage.download(image_gs.netloc, image_gs.path[1:]),
                    aio_storage.download(label_gs.netloc, label_gs.path[1:]),
                    loop=loop
                )
                # Convert both to RGB, matching gcs_pil_loader above
                image = Image.open(io.BytesIO(blobs[0])).convert("RGB")
                label = Image.open(io.BytesIO(blobs[1])).convert("RGB")
                # Blocks when the queue is full, which throttles the downloads
                await self.q.async_q.put((image, label))
            except aiohttp.ClientError as e:
                # Transient network errors: log and skip to the next item
                logging.debug(e)
            except asyncio.TimeoutError:
                # Timed-out downloads are dropped; the loop moves on
                pass
            except Exception as e:
                logging.exception(e)
    def init_worker(self):
        loop = asyncio.new_event_loop()
        session = aiohttp.ClientSession(loop=loop, connector=aiohttp.TCPConnector(limit=0, loop=loop),
                                        raise_for_status=True)
        # The bounded queue throttles producers once `concurrency` items wait
        self.q = Queue(self.concurrency, loop=loop)
        # Spin up workers
        for _ in range(self.concurrency):
            loop.create_task(self.run(loop, session))

        def loop_in_thread(loop):
            asyncio.set_event_loop(loop)
            loop.run_forever()

        # Run the event loop on a daemon thread so __iter__ can consume
        # synchronously from the other side of the janus queue
        self.loop_thread = Thread(target=loop_in_thread, args=(loop,), daemon=True)
        self.loop_thread.start()
        self.worker_initialized = True
    def __iter__(self):
        # Initialize lazily, inside the DataLoader worker process rather than
        # in the parent process that constructed the dataset
        while True:
            if not self.worker_initialized:
                self.init_worker()
            image, label = self.q.sync_q.get()
            if self.transform is not None:
                image, label = self.transform((image, label))
            yield image, label
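# `worker_init_fn` is defined in the full example linked above. A minimal
# sketch, assuming its job is to give each DataLoader worker process its own
# Python-level RNG seed so the workers do not all replay the same stream:
def worker_init_fn(worker_id: int) -> None:
    # torch.initial_seed() already differs per worker; fold it into `random`,
    # which generate_stream uses for sampling
    random.seed(torch.initial_seed() % 2 ** 32)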
def get_async_dataloaders(gcs_data_root: str, batch_size: int = 8, transform: Callable = None,
                          test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    # Async streaming
    streamed_items = load_items_gcs(os.path.join(gcs_data_root, "images"))
    train_indices, test_indices = consistent_train_test_split(streamed_items, test_ratio)
    train_items = [streamed_items[i.item()] for i in train_indices]
    train_dataset = AsyncImageDataset(gcs_data_root, train_items, transform=transform, concurrency=128)
    test_items = [streamed_items[i.item()] for i in test_indices]
    test_dataset = AsyncImageDataset(gcs_data_root, test_items, transform=transform, concurrency=128)
    return {
        "train": torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, worker_init_fn=worker_init_fn,
                                             num_workers=num_workers),
        "test": torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, worker_init_fn=worker_init_fn,
                                            num_workers=num_workers)
    }
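# Example usage with a hypothetical bucket; as above, pass a transform that
# returns tensors so the default collate function can batch:
#   loaders = get_async_dataloaders("gs://my-bucket/dataset", batch_size=4,
#                                   transform=joint_to_tensor, num_workers=2)
#   images, labels = next(iter(loaders["train"]))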