Skip to content

Instantly share code, notes, and snippets.

@yudai09
Created April 19, 2018 10:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yudai09/7be73fe2727734bf2bbf659f975eb6e2 to your computer and use it in GitHub Desktop.
Save yudai09/7be73fe2727734bf2bbf659f975eb6e2 to your computer and use it in GitHub Desktop.
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import time
from sklearn import datasets
import numpy
from PIL import Image
class ImageDataSet(Dataset):
    """Synthetic image dataset that returns the same image file for every index.

    Each sample is a list of ``images_per_item`` numpy arrays, each decoded
    from ``image_path``, paired with a 1-element label tensor holding 0.

    Args:
        X: Accepted for interface compatibility; not used.
        y: Accepted for interface compatibility; not used.
        image_path: Path of the image decoded for every sample.
        images_per_item: How many times the image is decoded per sample.
        length: Number of samples reported by ``len(dataset)``.
    """

    def __init__(self, X=None, y=None, image_path='example.png',
                 images_per_item=10, length=3000):
        self.image_path = image_path
        self.images_per_item = images_per_item
        self.length = length
        # Neither attribute is read by this class; both are kept so code that
        # accessed them on the original class keeps working.
        self.number = range(100)
        self.loader = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(images, label)``; the content does not depend on ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Without this, plain ``for x in dataset`` iteration (which
                stops only on IndexError) would never terminate.
        """
        if not -self.length <= idx < self.length:
            raise IndexError('index {} out of range for dataset of length {}'
                             .format(idx, self.length))
        images = []
        # One decode per entry, as in the original (presumably to mimic the
        # cost of distinct images — confirm before collapsing to one decode).
        for _ in range(self.images_per_item):
            # ``with`` closes the file deterministically; numpy.asarray reads
            # the pixel data before the image is closed.
            with Image.open(self.image_path) as img:
                images.append(numpy.asarray(img))
        return images, self._make_label()

    @staticmethod
    def _make_label():
        """Return the constant label: a shape-(1,) tensor containing 0."""
        # Same shape and default-int dtype as numpy.expand_dims(0, axis=0).
        return torch.from_numpy(numpy.array([0]))
def main():
    """Iterate a DataLoader over ImageDataSet for 100 epochs, printing each one."""
    loader = DataLoader(
        ImageDataSet(),
        batch_size=32,
        shuffle=False,  # iterate samples in index order
        num_workers=4,  # batches are loaded by 4 worker processes
    )
    for epoch_no in range(100):
        # Pull every batch; the batch contents are not inspected.
        for _images, _labels in loader:
            pass
        print("epoch: {}".format(epoch_no))
if __name__ == "__main__":
    # Run only when executed as a script. main() builds a DataLoader with
    # worker processes; under the "spawn" start method the workers re-import
    # this module, and this guard keeps them from re-running main().
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment