@jxmorris12
Last active December 8, 2023 12:21
Map a Hugging Face dataset across multiple workers using torch.distributed
from typing import Callable
import os
import shutil

import datasets
import torch

datasets.disable_caching()
cache_path = "/home/.cache"


def dataset_map_multi_worker(
    dataset: datasets.Dataset, map_fn: Callable, *args, **kwargs
) -> datasets.Dataset:
    try:
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    except (RuntimeError, ValueError):
        # torch.distributed is not initialized; fall back to a plain map().
        return dataset.map(map_fn, *args, **kwargs)
    ds_shard_filepaths = [
        os.path.join(cache_path, f"{dataset._fingerprint}_subshard_{w}.cache")
        for w in range(0, world_size)
    ]
    print(f"\tworker {rank} saving sub-shard to {ds_shard_filepaths[rank]}")
    # Each rank maps only its own contiguous shard of the dataset.
    ds_shard = dataset.shard(
        num_shards=world_size,
        index=rank,
        contiguous=True,
    )
    ds_shard = ds_shard.map(map_fn, *args, **kwargs)
    ds_shard.save_to_disk(ds_shard_filepaths[rank])
    print("rank", rank, "saving:", ds_shard_filepaths[rank])
    # Wait until every rank has written its shard, then reassemble the full dataset.
    torch.distributed.barrier()
    full_dataset = datasets.concatenate_datasets(
        [datasets.load_from_disk(p) for p in ds_shard_filepaths]
    )
    # Wait again so no rank deletes its shard before every rank has loaded all shards.
    torch.distributed.barrier()
    print("rank", rank, "deleting:", ds_shard_filepaths[rank])
    shutil.rmtree(ds_shard_filepaths[rank])
    return full_dataset
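

# Usage example: a minimal sketch, not part of the original gist. The toy
# dataset and the `add_length` map function are illustrative names only; in a
# real run the script would be launched with torchrun so that
# torch.distributed is initialized on every rank (otherwise the function
# simply falls back to a regular .map()).
if __name__ == "__main__":

    def add_length(example: dict) -> dict:
        # Record the length of the "text" field alongside each example.
        example["length"] = len(example["text"])
        return example

    toy_dataset = datasets.Dataset.from_dict(
        {"text": [f"example {i}" for i in range(16)]}
    )
    toy_dataset = dataset_map_multi_worker(toy_dataset, add_length)
    print(toy_dataset[0])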