
@halflearned
Created April 25, 2023 23:01
CIFAR-10 data sharding and upload
import os
import boto3
import numpy as np
import torch
import torchvision
from torch.utils.data import Subset


def upload_to_s3(local_path, s3_path):
    """Upload a local file to an s3://bucket/key destination."""
    s3 = boto3.client('s3', region_name="us-west-2")
    # Strip the "s3://" prefix, then split into bucket and key
    bucket, key = s3_path[5:].split('/', 1)
    print("bucket", bucket)
    print("key", key)
    with open(local_path, 'rb') as f:
        s3.upload_fileobj(f, bucket, key)
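

# Not part of the original gist: a matching download helper, sketched under the
# assumption that shards are later fetched back using the same "s3://bucket/key"
# path convention and region as above.
def download_from_s3(s3_path, local_path):
    s3 = boto3.client('s3', region_name="us-west-2")
    bucket, key = s3_path[5:].split('/', 1)
    with open(local_path, 'wb') as f:
        s3.download_fileobj(bucket, key, f)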


def download_cifar10():
    """Download the CIFAR-10 training and test sets to ./data."""
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    return trainset, testset


def split_into_subsets(dataset, num_shards):
    """Randomly partition a dataset into num_shards near-equal Subsets."""
    num_samples = len(dataset)
    indices = np.random.permutation(num_samples)
    # Base shard size, with the remainder spread over the first few shards
    shard_sizes = np.full(num_shards, num_samples // num_shards)
    shard_sizes[:num_samples % num_shards] += 1
    shards = []
    start = 0
    for shard_size in shard_sizes:
        end = start + shard_size
        shard_indices = indices[start:end]
        shard = Subset(dataset, shard_indices)
        shards.append(shard)
        start = end
    return shards
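
# Worked example: CIFAR-10 has 50,000 training images, and 50000 // 24 = 2083
# with remainder 8, so the first 8 shards hold 2084 images and the other 16
# hold 2083 (8 * 2084 + 16 * 2083 = 50000).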


def main():
    # S3 location
    s3_base_path = 's3://tritium-phase1-experiments/external-datasets/cifar10-sharded/'

    # Download CIFAR10 dataset
    trainset, testset = download_cifar10()

    # Split the training data into 24 subsets
    num_shards = 24
    shards = split_into_subsets(trainset, num_shards)

    # Save and upload each of the training subsets
    save_dir = './data/shards'
    os.makedirs(save_dir, exist_ok=True)
    for i, shard in enumerate(shards):
        shard_data = [(shard.dataset[idx][0], shard.dataset[idx][1]) for idx in shard.indices]
        shard_path = os.path.join(save_dir, f'shard-{i}.pt')
        torch.save(shard_data, shard_path)
        upload_to_s3(shard_path, f'{s3_base_path}train/train-shard-{i}.pt')

    # Save test set as a (single) PT file and upload to S3
    test_path = './data/shards/test.pt'
    torch.save(testset, test_path)
    upload_to_s3(test_path, f'{s3_base_path}test/test.pt')


if __name__ == '__main__':
    main()
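
To consume a shard during training, the saved .pt file can be loaded back and wrapped in a DataLoader. The sketch below is illustrative and assumes the shard file still sits on local disk at the path used above; the stored PIL images are converted to tensors at load time so the default collate function can batch them.

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Each shard is a list of (PIL image, label) pairs saved by the script above
shard_data = torch.load('./data/shards/shard-0.pt')

# Convert PIL images to tensors so they can be stacked into batches
to_tensor = transforms.ToTensor()
tensor_shard = [(to_tensor(img), label) for img, label in shard_data]

loader = DataLoader(tensor_shard, batch_size=128, shuffle=True)
for images, labels in loader:
    # images: [batch, 3, 32, 32] float tensor; labels: [batch] int64 tensor
    break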