Skip to content

Instantly share code, notes, and snippets.

@Tony363
Created August 6, 2023 03:55
Show Gist options
  • Save Tony363/8529f5b9309ac077941c80f72acd3a68 to your computer and use it in GitHub Desktop.
Save Tony363/8529f5b9309ac077941c80f72acd3a68 to your computer and use it in GitHub Desktop.
import os
from typing import Any, Callable

import torchvision
import torchvision.transforms as transforms
from datasets import load_dataset, DatasetDict,load_from_disk
from torch.utils.data import DataLoader
def process(
    example: dict,
    transform: Callable[[Any], Any],
) -> dict:
    """Apply ``transform`` to the ``'image'`` entry of a dataset row.

    Args:
        example: A single row; it must contain an ``'image'`` key. The dict
            is updated in place.
        transform: Callable invoked with the current ``example['image']``
            value (for instance a ``transforms.Compose`` pipeline). Its
            return value replaces the image.

    Returns:
        The same ``example`` object, with ``'image'`` replaced by the
        transformed value. All other keys are left untouched.

    Raises:
        KeyError: If ``example`` has no ``'image'`` key.
    """
    # NOTE: the value is handed to ``transform`` as-is. If this is ever used
    # with batched rows (a list of images under 'image'), the transform must
    # accept that list itself.
    example['image'] = transform(example['image'])
    return example
def main() -> None:
    """Load ImageNet-1k, re-save it locally and smoke-test one test batch.

    Steps:
        1. Build the train and test torchvision pipelines.
        2. Load the ``'imagenet-1k'`` dataset and run every split through
           ``map`` without a function, using roughly half of the CPUs.
        3. Save the resulting ``DatasetDict`` to the ``"imagenet-1k"``
           directory.
        4. Pull the first batch of the test split through a ``DataLoader``
           and print the shape of each image tensor in it.
    """
    normalize = transforms.Normalize(mean=[0], std=[1])
    # Augmenting pipeline meant for the training split. It is currently not
    # applied anywhere: the eager ``map(process, ...)`` pass that used it was
    # disabled by the author (see the issue linked below), so the splits are
    # stored untransformed.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # os.cpu_count() may return None, and halving a single core gives 0.
    # Clamp so that map() always receives a positive worker count.
    num_workers = max(1, (os.cpu_count() or 1) // 2)

    raw = load_dataset('imagenet-1k')
    # No function is passed to map(); per the datasets documentation this
    # defaults to the identity function, so rows are copied unchanged.
    # Related upstream discussion:
    # https://github.com/huggingface/datasets/issues/482
    ds = DatasetDict({
        split: raw[split].map(num_proc=num_workers)
        for split in ('train', 'validation', 'test')
    })
    # NOTE(review): the output directory has the same name as the dataset id
    # passed to load_dataset() above. Confirm a second run from the same
    # working directory still resolves the intended dataset.
    ds.save_to_disk("imagenet-1k")

    # Rows are dicts and their 'image' entries have not been converted to
    # tensors, so the default collate function cannot build a batch from
    # them. Convert here, and keep the images as a list: nothing resizes
    # them, so they presumably differ in size and cannot be stacked.
    dataloader = DataLoader(
        ds['test'],
        batch_size=2,
        collate_fn=lambda rows: [test_transform(row['image']) for row in rows],
    )
    for images in dataloader:
        for img in images:
            print(img.shape)
        break  # one batch is enough for a smoke test
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment