# Sharpening augmenter applied to every image loaded by Dataset.__getitem__.
# Both arguments are scalars (not (low, high) ranges), so every image receives
# the same fixed effect: alpha=1.0 shows only the sharpened result, and
# lightness=1.5 (>1.0) brightens it.
aug = iaa.Sharpen(alpha=1.0, lightness=1.5)

def adjust_data(img, mask):
    """Normalise an image/mask pair for training.

    Args:
        img: numpy image array, presumably 0-255 valued (TODO confirm
            against callers). The caller's array is not modified.
        mask: numpy mask array, presumably 0-255 valued. Not modified.

    Returns:
        ``(img, mask)``. In ``img``, every element below 0.2 is replaced by
        0.5 and the result is divided by 255. ``mask`` is divided by 255 and
        binarised: values > 0.5 become 1 and all others become 0.
    """
    # Copy first: the original assigned into the caller's array in place,
    # which silently altered data the caller still held a reference to.
    img = img.copy()
    # NOTE(review): this threshold is applied *before* dividing by 255, so on
    # 0-255 data it only affects pixels equal to 0 (which end up as 0.5/255).
    # If the intent was to lift near-black pixels after normalisation, this
    # line belongs below the division. Confirm with the data owner.
    img[img < 0.2] = 0.5
    img = img / 255

    # Division returns a new array, so the in-place thresholding below does
    # not touch the caller's mask.
    mask = mask / 255
    mask[mask > 0.5] = 1
    mask[mask <= 0.5] = 0

    return (img, mask)

class Dataset:
    """Index-addressable collection of (image, mask) samples read from disk.

    Each row of ``dataframe`` describes one sample:

    * ``'Patient'``: an identifier (only used for ``len()``);
    * ``'img'``: path of the input image;
    * ``'mask'``: path of its segmentation mask.

    Adapt the column names here if your data uses different ones. Samples are
    addressed by *position* (0..len-1), whatever the DataFrame's index is.
    """

    def __init__(self, dataframe):
        # Sample identifiers; their count defines the dataset length.
        self.ids = dataframe['Patient']
        # Paths of the input images.
        self.images_fps = dataframe['img']
        # Paths of the segmentation masks.
        self.masks_fps = dataframe['mask']

    @staticmethod
    def _positional(column, i):
        """Return the ``i``-th element of ``column`` by position.

        ``series[i]`` on a pandas Series is a *label* lookup. That raises
        KeyError, or returns the wrong row, once the index is not 0..n-1
        (e.g. after a shuffled train/test split). ``.iloc`` is always
        positional. Plain sequences without ``.iloc`` use ordinary indexing.
        """
        iloc = getattr(column, 'iloc', None)
        return iloc[i] if iloc is not None else column[i]

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(image, mask)`` via ``adjust_data``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or mask file.
            ValueError: if an image channel or mask is not 256x256 (from
                ``np.reshape``).
        """
        image_path = self._positional(self.images_fps, i)
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = aug.augment_image(image)
        # Keep only channel index 1 (green, if OpenCV loaded a BGR/BGRA image).
        image = image[:, :, 1]
        # Add a trailing channel axis; assumes 256x256 inputs.
        image = np.reshape(image, (256, 256, 1)).astype(np.float32)

        mask_path = self._positional(self.masks_fps, i)
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")
        mask = np.reshape(mask, (256, 256, 1)).astype(np.float32)

        return adjust_data(image, mask)

    def __len__(self):
        """Number of samples (rows in the source dataframe)."""
        return len(self.ids)
    
    
class Dataloder(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that groups ``dataset`` samples into batches.

    ``dataset[k]`` must return a tuple of arrays (e.g. ``(image, mask)``).
    Each batch is the tuple of those arrays stacked along a new leading
    axis. Trailing samples that do not fill a whole batch are dropped (see
    ``__len__``).

    Args:
        dataset: object supporting ``len()`` and integer indexing.
        batch_size: number of samples per batch.
        shuffle: if True, the sample order is randomly permuted at
            construction and again after every epoch.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))
        # Shuffle up front as well, so the first epoch is not in file order.
        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch ``i`` as a tuple of stacked arrays."""
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        # Map batch positions through self.indexes so the permutation from
        # on_epoch_end actually takes effect. The original indexed the
        # dataset with j directly, which made shuffle=True a no-op.
        data = [self.dataset[self.indexes[j]] for j in range(start, stop)]

        batch = [np.stack(samples) for samples in zip(*data)]
        return tuple(batch)

    def __len__(self):
        """Number of full batches per epoch (a partial last batch is dropped)."""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Re-permute the sample order when shuffling is enabled."""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)