Created
April 25, 2023 23:01
-
-
Save halflearned/f39bd588455beb74462f060c6d819632 to your computer and use it in GitHub Desktop.
CIFAR-10 data sharding and upload
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import boto3
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
def upload_to_s3(local_path, s3_path):
    """Upload a local file to the S3 object named by an ``s3://`` URI.

    Args:
        local_path: Path of the file on disk to upload.
        s3_path: Destination in the form ``s3://<bucket>/<key>``.

    Raises:
        ValueError: If ``s3_path`` does not start with ``s3://`` or is
            missing the bucket or the key.
    """
    scheme = 's3://'
    # Validate before creating a client: blindly slicing off five characters
    # would turn a malformed URI into a wrong bucket name instead of an error.
    if not s3_path.startswith(scheme):
        raise ValueError(f"expected an s3://bucket/key URI, got {s3_path!r}")
    bucket, _, key = s3_path[len(scheme):].partition('/')
    if not bucket or not key:
        raise ValueError(f"S3 URI must name both a bucket and a key, got {s3_path!r}")

    s3 = boto3.client('s3', region_name="us-west-2")
    print("bucket", bucket)
    print("key", key)
    with open(local_path, 'rb') as f:
        s3.upload_fileobj(f, bucket, key)
def download_cifar10():
    """Fetch the CIFAR-10 train and test splits into ``./data``.

    Returns:
        A ``(trainset, testset)`` pair of ``torchvision.datasets.CIFAR10``
        objects, downloaded on demand.
    """
    data_root = './data'
    # Train split is requested first, then the test split.
    train_split, test_split = (
        torchvision.datasets.CIFAR10(root=data_root, train=is_train, download=True)
        for is_train in (True, False)
    )
    return train_split, test_split
def split_into_subsets(dataset, num_shards, seed=None):
    """Randomly partition ``dataset`` into ``num_shards`` disjoint subsets.

    The sample indices are shuffled and then cut into consecutive runs whose
    sizes differ by at most one; when the split is uneven, the earlier shards
    receive the extra sample.

    Args:
        dataset: Any object supporting ``len()``; it is wrapped, not copied.
        num_shards: Number of subsets to produce; must be at least 1.
        seed: Optional seed for a dedicated random generator, giving a
            reproducible split. When ``None`` (the default) the global NumPy
            random state is used.

    Returns:
        A list of ``num_shards`` ``torch.utils.data.Subset`` objects that
        together cover every index of ``dataset`` exactly once.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be a positive integer, got {num_shards}")

    num_samples = len(dataset)
    if seed is None:
        indices = np.random.permutation(num_samples)
    else:
        indices = np.random.default_rng(seed).permutation(num_samples)

    # array_split hands the first (num_samples % num_shards) chunks one extra
    # element, so the shard sizes are as even as possible.
    return [
        Subset(dataset, shard_indices)
        for shard_indices in np.array_split(indices, num_shards)
    ]
def main():
    """Download CIFAR-10, shard the training split, and upload it all to S3.

    The training set is split into 24 random shards. Each shard is saved
    under ``./data/shards`` as a list of two-element tuples holding the first
    two fields of every sample, then uploaded beneath ``train/`` at the S3
    base path. The test set is saved as a single file and uploaded beneath
    ``test/``.
    """
    # S3 location
    s3_base_path = 's3://tritium-phase1-experiments/external-datasets/cifar10-sharded/'

    # Download CIFAR10 dataset
    trainset, testset = download_cifar10()

    # Split the training data into 24 subsets
    num_shards = 24
    shards = split_into_subsets(trainset, num_shards)

    # Save and upload each of the training subsets
    save_dir = './data/shards'
    os.makedirs(save_dir, exist_ok=True)
    for i, shard in enumerate(shards):
        # Fetch every sample exactly once; indexing the dataset separately
        # for each field would repeat the per-sample loading work.
        samples = (shard.dataset[idx] for idx in shard.indices)
        shard_data = [(sample[0], sample[1]) for sample in samples]
        shard_path = os.path.join(save_dir, f'shard-{i}.pt')
        torch.save(shard_data, shard_path)
        upload_to_s3(shard_path, f'{s3_base_path}train/train-shard-{i}.pt')

    # Save test set as a (single) PT file and upload to S3.
    # NOTE(review): unlike the training shards, this serializes the dataset
    # object itself rather than a list of tuples; confirm consumers expect it.
    test_path = os.path.join(save_dir, 'test.pt')
    torch.save(testset, test_path)
    upload_to_s3(test_path, f'{s3_base_path}test/test.pt')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment