Last active
July 19, 2018 07:43
-
-
Save justusschock/0d80e5b01f07f388d79dc851a0358d14 to your computer and use it in GitHub Desktop.
Custom Dataloading
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UnalignedPairedData(object):
    """Class to combine two items of 2 datasets.

    Each step of iteration combines the next item of loader A with the next
    item of loader B. When one loader runs out before the other, it is
    restarted from the beginning (``iter()`` is called on it again), so the
    shorter loader is cycled. A pass ends on the step at which both loaders
    have been exhausted at least once; for two non-empty, re-iterable loaders
    this yields ``max(len(A), len(B))`` pairs. If a loader yields nothing even
    right after a restart, the pass ends immediately.

    :param data_loader_a: re-iterable source for the ``'A'`` entries
    :param data_loader_b: re-iterable source for the ``'B'`` entries
    :param return_paths: if True, every item of both loaders must unpack into
        ``(data, path)``; the yielded dicts then also carry ``'A_Path'`` and
        ``'B_Path'``
    """

    def __init__(self, data_loader_a, data_loader_b, return_paths=True):
        """Store both loaders; their iterators are created by ``__iter__``."""
        super(UnalignedPairedData, self).__init__()
        self.return_paths = return_paths
        self.dataLoaderA = data_loader_a
        self.dataLoaderB = data_loader_b
        self.dataLoaderAIter = None
        self.dataLoaderBIter = None
        # True once the respective loader has been exhausted in the current pass
        self.stopA = False
        self.stopB = False
        # number of pairs returned in the current pass (None before __iter__)
        self.iter = None

    def __iter__(self):
        """
        Start a fresh pass over both loaders.

        :return: self
        """
        self.stopA = False
        self.stopB = False
        self.dataLoaderAIter = iter(self.dataLoaderA)
        self.dataLoaderBIter = iter(self.dataLoaderB)
        self.iter = 0
        return self

    @staticmethod
    def _draw(loader, iterator, other_stopped):
        """
        Fetch the next item of ``iterator``, restarting ``loader`` if it is dry.

        :param loader: the re-iterable the iterator was created from
        :param iterator: the iterator currently in use for ``loader``
        :param other_stopped: whether the other loader was already exhausted
            in this pass
        :return: ``(item, iterator, restarted)``; ``iterator`` is the one to
            keep using, ``restarted`` tells whether the loader was restarted
        :raises StopIteration: if the iterator is dry and the other loader is
            already exhausted (end of pass), or if the loader is empty
        """
        try:
            return next(iterator), iterator, False
        except StopIteration:
            pass
        if other_stopped:
            # Both loaders are now exhausted: the pass is over, so skip the
            # restart - creating a new iterator and drawing from it would only
            # produce an item that is thrown away.
            raise StopIteration()
        iterator = iter(loader)
        # An empty loader raises StopIteration here, ending the pass.
        return next(iterator), iterator, True

    def __next__(self):
        """
        Function to get next items of datasets

        :return: dictionary ``{'A': a, 'B': b}``, plus ``'A_Path'`` and
            ``'B_Path'`` when ``return_paths`` is set
        :raises StopIteration: once both loaders have been exhausted in this
            pass, or when a loader is empty
        """
        try:
            a, self.dataLoaderAIter, restarted = self._draw(
                self.dataLoaderA, self.dataLoaderAIter, self.stopB)
            self.stopA = self.stopA or restarted
            b, self.dataLoaderBIter, restarted = self._draw(
                self.dataLoaderB, self.dataLoaderBIter, self.stopA)
            self.stopB = self.stopB or restarted
        except StopIteration:
            # leave the flags clean for whoever continues iterating
            self.stopA = False
            self.stopB = False
            raise

        if self.return_paths:
            a, a_path = a
            b, b_path = b
            self.iter += 1
            return {'A': a, 'B': b, 'A_Path': a_path, 'B_Path': b_path}

        self.iter += 1
        return {'A': a, 'B': b}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment