Skip to content

Instantly share code, notes, and snippets.

@azkalot1
Created April 25, 2019 11:36
Show Gist options
  • Save azkalot1/f00f6c7d34137a8170cb4517718f8f50 to your computer and use it in GitHub Desktop.
Save azkalot1/f00f6c7d34137a8170cb4517718f8f50 to your computer and use it in GitHub Desktop.
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import Optimizer
from torch.utils import data
from torchvision import datasets, transforms
class DataGenerator(data.Dataset):
    """Map-style dataset that loads images from disk by id.

    Each item is ``(image, label)``:

    * ``image`` is the file ``<imdir>/<id><ext>``, optionally augmented,
      divided by 255 and rearranged so the channel axis comes first.
    * ``label`` is the label with a leading axis added, so a scalar label
      becomes an array of shape ``(1,)``.

    Args:
        ids: Image ids (file names without extension), indexable by position.
        labels: Labels of the images (1/0), in the same order as ``ids``.
        augment: Optional albumentations-style augmentation, called as
            ``augment(image=im)``. It must return a dict with an ``'image'``
            key. ``None`` disables augmentation.
        imdir: Path to the folder containing the images.
        ext: File extension appended to each id. Defaults to ``'.tif'``.

    Note:
        ``imread`` is not imported in this module's visible imports. It must
        exist at module level (presumably ``skimage.io.imread`` or similar;
        TODO confirm).
    """

    def __init__(self, ids, labels, augment, imdir, ext='.tif'):
        """Store the ids, labels, augmentation, image folder and extension."""
        self.ids, self.labels = ids, labels
        self.augment = augment
        self.imdir = imdir
        self.ext = ext

    def __len__(self):
        """Return the number of samples, i.e. the number of ids."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        imid = self.ids[idx]
        y = self.labels[idx]
        X = self.__load_image(imid)
        # Add a leading axis so a scalar label becomes shape (1,).
        return X, np.expand_dims(y, 0)

    def __load_image(self, imid):
        """Read ``<imdir>/<imid><ext>`` from disk and preprocess it."""
        path = os.path.join(self.imdir, imid + self.ext)
        return self._preprocess(imread(path))

    def _preprocess(self, im):
        """Augment, rescale and move the channel axis of an image to the front.

        Args:
            im: Image array, either 2-D (single channel) or with the
                channel axis last. Channels-last is assumed, as produced by
                typical image readers; TODO confirm for the reader in use.

        Returns:
            The image divided by 255, with the channel axis first. A 2-D
            input gets a new leading axis of length 1.
        """
        if self.augment is not None:
            im = self.augment(image=im)['image']
        im = im / 255.0
        if im.ndim == 2:
            # Single-channel image: add a channel axis. Moving the last axis
            # here would silently transpose height and width.
            return im[np.newaxis, ...]
        # Move the trailing (channel) axis to position 0.
        return np.moveaxis(im, -1, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment