Skip to content

Instantly share code, notes, and snippets.

@jguertl
Last active March 13, 2019 15:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jguertl/864de051fcc8812e7289312c3c95a261 to your computer and use it in GitHub Desktop.
Save jguertl/864de051fcc8812e7289312c3c95a261 to your computer and use it in GitHub Desktop.
class DataSet(gluon.data.Dataset):
    """In-memory dataset of images read from ``<root>/trainA`` and ``<root>/trainB``.

    All listed images are read and normalized once, at construction time, and
    kept in memory. Item ``i`` pairs the i-th image of ``DomainAList`` with the
    i-th image of ``DomainBList``.

    Args:
        root: Directory containing the ``trainA`` (data) and ``trainB``
            (label) subdirectories.
        DomainAList: File names, relative to ``<root>/trainA``, of the A images.
        DomainBList: File names, relative to ``<root>/trainB``, of the B images.
    """

    def __init__(self, root, DomainAList, DomainBList):
        self.root = root
        self.DomainAList = DomainAList
        self.DomainBList = DomainBList
        # Eagerly read + normalize everything; populates self.A and self.B.
        self.load_images()

    def read_images(self, root):
        """Read every listed image from disk with ``image.imread``.

        Args:
            root: Directory containing the ``trainA`` and ``trainB`` subdirectories.

        Returns:
            A tuple ``(A, B)`` of lists holding the images read for
            ``DomainAList`` and ``DomainBList`` respectively, in list order.
        """
        # Local import keeps this fix self-contained; os is only needed here.
        import os

        # os.path.join is correct whether or not ``root`` ends with a
        # separator; plain ``root + 'trainA/'`` built a wrong path without one.
        a_dir = os.path.join(root, 'trainA')  # data
        b_dir = os.path.join(root, 'trainB')  # label
        A = [image.imread(os.path.join(a_dir, name)) for name in self.DomainAList]
        B = [image.imread(os.path.join(b_dir, name)) for name in self.DomainBList]
        return A, B

    def load_images(self):
        """Read all images under ``self.root`` and store normalized copies in ``self.A`` / ``self.B``."""
        A, B = self.read_images(root=self.root)
        self.A = [self.normalize_image(im) for im in A]
        self.B = [self.normalize_image(im) for im in B]

    def normalize_image(self, A):
        """Cast an image to float32 and divide it by 255.

        NOTE(review): presumably maps 8-bit pixel values into [0, 1]; this
        assumes ``image.imread`` yields values in 0..255 -- confirm.
        """
        return A.astype('float32') / 255

    def __getitem__(self, item):
        """Return the ``item``-th (A, B) pair with each array's axes permuted by (2, 0, 1).

        Assumes the stored images are 3-D and channels-last (H, W, C), so the
        returned arrays are channels-first (C, H, W) -- TODO confirm against
        ``image.imread``'s output layout.
        """
        A = self.A[item]
        B = self.B[item]
        return A.transpose((2, 0, 1)), B.transpose((2, 0, 1))

    def __len__(self):
        """Return the number of complete (A, B) pairs.

        ``__getitem__`` indexes both lists with the same index, so the usable
        length is bounded by the shorter list. Counting only ``self.A`` made
        the trailing indices raise ``IndexError`` whenever domain B had fewer
        images than domain A.
        """
        return min(len(self.A), len(self.B))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment