Skip to content

Instantly share code, notes, and snippets.

@shuuchen
Last active September 10, 2018 03:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shuuchen/7e475b91df25fbea4997e2d17f5fedd0 to your computer and use it in GitHub Desktop.
PyTorch でシーケンスデータを順番に読み込む ref: https://qiita.com/shuuchen/items/466dc7977a146f7f38f2
class ImageFolderWithPaths(datasets.ImageFolder):
    """torchvision.datasets.ImageFolder that also reports each image's file path.

    Every item is the tuple the parent ImageFolder returns for the same index,
    with the image's file path appended as the final element. Batches from a
    DataLoader over this dataset therefore unpack as ``inputs, _, paths``.
    """

    def __getitem__(self, index):
        """Return the parent item for ``index`` extended with its file path.

        DataLoader calls this method for every sample it fetches.
        """
        base_item = super().__getitem__(index)
        # self.imgs holds (path, class_index) pairs; element 0 is the path.
        file_path = self.imgs[index][0]
        return base_item + (file_path,)
# Dataset root; each split is loaded from <data_dir>/<split>/.
data_dir = './pregnant'

# Resize every image to 224x224, convert it to a tensor, then normalize each
# channel with the mean/std below (presumably the usual ImageNet statistics).
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# One dataset per split; only the 'all' split is used in this snippet.
image_datasets = {
    split: ImageFolderWithPaths(os.path.join(data_dir, split),
                                transform=data_transforms)
    for split in ['all']
}

# shuffle=False keeps the dataset's own ordering. As the printed paths show,
# that ordering is lexicographic by file name (0.jpg, 1.jpg, 10.jpg, 100.jpg,
# ...), not numeric. NOTE(review): batch_size is assumed to be defined earlier.
data_loaders = {
    split: torch.utils.data.DataLoader(image_datasets[split],
                                       batch_size=batch_size,
                                       shuffle=False)
    for split in ['all']
}
dataset_sizes = {split: len(image_datasets[split]) for split in ['all']}

# Inspect the file paths of the first batch only.
for inputs, _, paths in data_loaders['all']:
    print(paths)
    break
'0.jpg',
'1.jpg',
'10.jpg',
'100.jpg',
'1000.jpg',
'1001.jpg',
'1002.jpg',
'1003.jpg',
'1004.jpg',
'1005.jpg',
'1006.jpg',
'1007.jpg'
...
from PIL import Image

data_iter = iter(data_loaders['all'])

# Main approach: build mini-batches from images read in numeric file order
# (0.jpg, 1.jpg, 2.jpg, ...) instead of the lexicographic order that
# ImageFolder uses (0.jpg, 1.jpg, 10.jpg, 100.jpg, ...).
#
# The paths come from the dataset itself, so the files are opened from the
# dataset directory rather than the current working directory. They are
# sorted by the integer in the file name. Assumes every file stem is an
# integer (e.g. "123.jpg") and that numbers do not repeat across class
# folders -- TODO confirm for the whole dataset.
sorted_paths = sorted(
    (path for path, _ in image_datasets['all'].imgs),
    key=lambda p: int(os.path.splitext(os.path.basename(p))[0]),
)

# Slide a window of `batch_size` consecutive images over the sequence.
# The `+ 1` includes the last window, which ends at the final image.
for i in range(len(sorted_paths) - batch_size + 1):
    imgs = []
    for path in sorted_paths[i:i + batch_size]:
        print(path)
        # Same as torchvision's default PIL loader for ImageFolder: release
        # the file handle and force 3-channel RGB, so these tensors match
        # what data_loaders['all'] produces.
        with Image.open(path) as pil_img:
            imgs.append(data_transforms(pil_img.convert('RGB')))
    print(len(imgs))
    imgs = torch.stack(imgs)
    print(imgs.size())
    break
# For comparison: the size of the first batch produced by the regular
# (lexicographically ordered) DataLoader, to check against the stacked
# tensor built from the numerically ordered images.
for inputs, _, paths in data_loaders['all']:
    print(inputs.size())
    break
0.jpg
1.jpg
2.jpg
3.jpg
4.jpg
5.jpg
6.jpg
7.jpg
8.jpg
...
36
torch.Size([36, 3, 224, 224])
torch.Size([36, 3, 224, 224])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment