Last active
May 16, 2018 11:31
-
-
Save justusschock/dd28afb1dc5ea8f313a10bd25a001802 to your computer and use it in GitHub Desktop.
PyTorch dataset and data loader for FITS data files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils import data as data | |
from torchvision import transforms | |
import os | |
from astropy.io import fits | |
from skimage.transform import resize | |
# File-name suffixes that mark a file as a FITS image when scanning a dataset directory.
IMG_EXTENSIONS = [".fits"]
def is_image_file(filename, extensions=None):
    """
    Helper Function to determine whether a file is an image file or not

    The check looks only at the file-name suffix and ignores case, so both
    ``image.fits`` and ``IMAGE.FITS`` are accepted.

    :param filename: the filename containing a possible image
    :param extensions: iterable of accepted suffixes; when None the
        module-level ``IMG_EXTENSIONS`` is used
    :return: True if file is image file, False otherwise
    """
    if extensions is None:
        extensions = IMG_EXTENSIONS
    # str.endswith accepts a tuple of suffixes (one C-level call); lower-case
    # both sides so the result does not depend on how the name is capitalized.
    suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)
def make_dataset(dir):
    """
    Helper Function to make a dataset containing all images in a certain directory

    The directory is walked recursively. Walk entries are processed in order
    of their directory path, and the file names inside each directory are
    sorted as well. The returned list, and therefore every dataset index, is
    reproducible across runs and machines.

    :param dir: the directory containing the dataset
    :return: images: list of image paths
    :raises AssertionError: if ``dir`` is not an existing directory
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the AssertionError type is kept so existing callers see no change.
    if not os.path.isdir(dir):
        raise AssertionError('%s is not a valid directory' % dir)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk lists file names in arbitrary (filesystem-dependent) order.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images
def default_fits_loader(file_name: str, img_size: tuple, hdu_index: int = 1):
    """
    Load one FITS file, resize its image data and return an (image, label) pair

    :param file_name: path of the FITS file to read
    :param img_size: target shape passed to ``skimage.transform.resize``
    :param hdu_index: index of the HDU whose ``.data`` is loaded. Defaults
        to 1 (the first extension HDU). Pass 0 for files that keep the image
        in the primary HDU.
    :return: tuple ``(data, label)``. ``data`` gets a trailing axis of size 1
        when it has fewer than 3 dimensions. The label is currently the very
        same array object as ``data`` (placeholder, see TODO below).
    """
    # The context manager closes the file deterministically instead of relying
    # on garbage collection. ``resize`` builds a new array, so the result is
    # still usable after the file is closed.
    with fits.open(file_name) as hdul:
        _data = resize(hdul[hdu_index].data, img_size)

    # add channels: e.g. a 2-D (H, W) image becomes (H, W, 1)
    if _data.ndim < 3:
        _data = _data.reshape((*_data.shape, 1))

    # TODO: Insert Custom Label Loader
    _label = _data
    return _data, _label
class FITSDataset(data.Dataset):
    """
    Map-style dataset over every FITS file found below a directory

    :param data_path: root directory that is searched recursively for FITS files
    :param transforms: optional callable applied to the image only (the label
        is returned untouched); pass None to disable it
    :param img_size: target shape forwarded to ``default_fits_loader``
    """

    def __init__(self, data_path, transforms, img_size):
        self.data_path = data_path
        self.transforms = transforms
        self.img_size = img_size
        # The file list is collected once here; indices address this list.
        self.img_files = make_dataset(data_path)

    def __getitem__(self, index):
        """
        Load the sample stored in the ``index``-th discovered file

        :param index: position in ``self.img_files``
        :return: tuple ``(image, label)``, with the image transformed when a
            transform was configured
        """
        image, label = default_fits_loader(self.img_files[index], self.img_size)
        transform = self.transforms
        if transform is None:
            return image, label
        return transform(image), label

    def __len__(self):
        """
        :return: number of FITS files discovered under ``data_path``
        """
        return len(self.img_files)
if __name__ == '__main__':
    # Smoke test: run the dataset through a DataLoader and write every image to disk.
    from matplotlib import pyplot as plt

    save_dir = "PATH_TO_SAVE_DIR"
    dataset = FITSDataset("PATH_TO_DATA_DIR", transforms.ToTensor(), (64, 64))
    data_loader = data.DataLoader(dataset, batch_size=2)

    for idx, (images, _labels) in enumerate(data_loader):
        print(idx)
        # A batch stacks several samples along dim 0. Squeezing the whole
        # batch leaves a (batch, H, W) array, which plt.imsave rejects, so
        # save each sample separately. Unique names also stop one file from
        # being overwritten on every iteration.
        for sample_idx, image in enumerate(images):
            out_path = os.path.join(save_dir, "FILENAME_%d_%d.png" % (idx, sample_idx))
            plt.imsave(out_path, image.squeeze().numpy())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment