@burrussmp
Last active May 5, 2020 19:08
An example of a PyTorch DataLoader that uses ListDataset and a load_data function.
```python
import os

# DATADIR (the root folder holding the train/valid/test subfolders) is
# defined elsewhere in the original project; it is not shown in this gist.

def get_list(split='train'):
    """
    Helper function for the PyTorch data loader.

    @params
    split: string
        Specifies whether the training ('train'), validation ('valid'), or
        testing ('test') list should be generated.
    @return
    mlist: a nested Python list
        A list of input-output pairs, where each element is a list of size 2.
        The first element is the path to the .npy input file and the second
        element is the path to the .npy one-hot-encoded segmentation map.
    """
    assert split in ['train', 'valid', 'test'], \
        'split must be train, valid, or test'
    path = os.path.join(DATADIR, split)
    # Each sample is stored as two files (img_<i>.npy and target_<i>.npy),
    # so the folder is assumed to hold exactly two files per sample.
    num_pairs = len(os.listdir(path)) // 2
    mlist = []
    for i in range(num_pairs):
        path_to_img = os.path.join(path, 'img_{}.npy'.format(i))
        path_to_target = os.path.join(path, 'target_{}.npy'.format(i))
        mlist.append([path_to_img, path_to_target])
    return mlist
```
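get_list infers the number of samples as half the number of files in the split folder, which silently assumes the folder contains nothing but matching img_<i>.npy / target_<i>.npy pairs numbered from 0. As a hedged sketch (the name get_pairs is hypothetical and not part of the gist), the pairing can instead be driven by the image filenames themselves, with a check that each target file exists:

```python
import os
import re

# Hypothetical variant of get_list: takes the split folder directly and
# pairs img_<i>.npy with target_<i>.npy by index instead of by file count.
IMG_PATTERN = re.compile(r"^img_(\d+)\.npy$")


def get_pairs(split_dir):
    indexed = []
    for name in os.listdir(split_dir):
        match = IMG_PATTERN.match(name)
        if match is None:
            continue  # skip target files and anything unrelated
        idx = int(match.group(1))
        target = os.path.join(split_dir, "target_{}.npy".format(idx))
        if not os.path.isfile(target):
            raise FileNotFoundError("no target file for {}".format(name))
        indexed.append((idx, [os.path.join(split_dir, name), target]))
    # Sort by the numeric index so img_10 comes after img_9.
    indexed.sort(key=lambda item: item[0])
    return [pair for _, pair in indexed]
```

Stray files (a README, a hidden OS file) are ignored here rather than shifting the pair count, and a missing target fails loudly instead of surfacing later as a np.load error.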
```python
import numpy as np

# random_translation, draw_random_shape, randomly_thin, and
# randomly_change_brightness are augmentation helpers defined elsewhere in
# the original project; they are not included in this gist.

def load_data(line):
    """
    Loads one sample for the PyTorch data loader and performs data augmentation:
    1. Randomly translates the input (together with its target)
    2. Draws a random shape on the input
    3. Randomly thins the input
    4. Randomly changes the brightness

    @params
    line: Python list (len == 2)
        line[0]: path to a NumPy array containing the input, HxWx1
        line[1]: path to a NumPy array containing the segmentation matrix, HxWx(num_labels+1)
    @return
    Python dictionary
        key 'src' (np.ndarray): the augmented input, 1xHxW
        key 'target' (np.ndarray): the augmented one-hot-encoded target, HxWx(num_labels+1)
    """
    path_to_image, path_to_target = line
    img = np.load(path_to_image)
    target = np.load(path_to_target)
    img = np.squeeze(img)                          # HxWx1 -> HxW
    img, target = random_translation(img, target)  # shift input and target together
    src = draw_random_shape(img)                   # draw a random shape
    src = randomly_thin(src, p=0.35)               # possibly thin (p=0.35)
    src = randomly_change_brightness(src)          # change brightness
    src = np.expand_dims(src, axis=0)              # PyTorch expects channels first: 1xHxW
    return {'src': src, 'target': target}
```
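load_data relies on random_translation to shift the input and its one-hot target by the same offset; a target that moved independently of the input would no longer line up with it. The gist does not include that helper, so the following is only a sketch of one way it could work. The names shift_2d and max_shift, the zero fill for the image, and the assumption that channel 0 of the target is the background class are all hypothetical.

```python
import numpy as np


def _axis_slices(n, shift):
    # Source/destination slices for shifting an axis of length n by `shift`.
    # Both clamp to empty slices when |shift| >= n.
    src = slice(max(0, -shift), max(0, min(n, n - shift)))
    dst = slice(max(0, shift), max(0, min(n, n + shift)))
    return src, dst


def shift_2d(arr, dy, dx, fill):
    """Shift arr (HxW or HxWxC) by (dy, dx) along its first two axes,
    filling the vacated cells with `fill` (a scalar or a length-C vector)."""
    out = np.empty_like(arr)
    out[...] = fill  # a length-C fill broadcasts over HxWxC
    h, w = arr.shape[:2]
    src_y, dst_y = _axis_slices(h, dy)
    src_x, dst_x = _axis_slices(w, dx)
    out[dst_y, dst_x] = arr[src_y, src_x]
    return out


def random_translation(img, target, max_shift=4, background_channel=0, rng=None):
    """Hypothetical joint translation: img is HxW, target is HxWxC one-hot.
    Assumes `background_channel` marks the background class."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape
    max_dy = min(max_shift, h - 1)
    max_dx = min(max_shift, w - 1)
    dy = int(rng.integers(-max_dy, max_dy + 1))  # upper bound is exclusive
    dx = int(rng.integers(-max_dx, max_dx + 1))
    # Vacated target pixels become background so every pixel stays one-hot.
    bg = np.zeros(target.shape[-1], dtype=target.dtype)
    bg[background_channel] = 1
    return shift_2d(img, dy, dx, 0), shift_2d(target, dy, dx, bg)
```

Filling the vacated target pixels with the background one-hot vector, rather than zeros, keeps every pixel a valid one-hot vector. Note that load_data leaves the target channels-last (HxWx(num_labels+1)) while the input becomes channels-first.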
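```python
from torch.utils.data import DataLoader
training_list = get_list('train')
training_dataset = ListDataset(training_list, load_data)  # ListDataset: project helper, not part of torch.utils.data
train_loader = DataLoader(dataset=training_dataset, batch_size=32)  # shuffle defaults to False
```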
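ListDataset is not defined in the gist and is not part of torch.utils.data. Judging only from how it is called, it takes a list of items and a load function. A minimal hypothetical version follows. DataLoader treats any object that implements `__len__` and `__getitem__` as a map-style dataset, so no base class is strictly required here, although subclassing `torch.utils.data.Dataset` is the more common convention.

```python
class ListDataset:
    """Hypothetical stand-in for the gist's ListDataset: a map-style dataset
    that applies load_fn lazily to items[index]."""

    def __init__(self, items, load_fn):
        self.items = list(items)
        self.load_fn = load_fn

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # Loading and augmentation run on every access, so each epoch
        # draws fresh random augmentations for the same file pair.
        return self.load_fn(self.items[index])
```

With this sketch, the existing lines work unchanged: `DataLoader(dataset=ListDataset(training_list, load_data), batch_size=32)` batches the `'src'` and `'target'` entries of each returned dictionary.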