Skip to content

Instantly share code, notes, and snippets.

@justusschock
Last active May 16, 2018 11:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save justusschock/dd28afb1dc5ea8f313a10bd25a001802 to your computer and use it in GitHub Desktop.
Save justusschock/dd28afb1dc5ea8f313a10bd25a001802 to your computer and use it in GitHub Desktop.
dataloader for FITS data files
from torch.utils import data as data
from torchvision import transforms
import os
from astropy.io import fits
from skimage.transform import resize
# File-name suffixes that mark a file as an image for this dataset.
IMG_EXTENSIONS = [".fits"]
def is_image_file(filename: str) -> bool:
    """
    Helper Function to determine whether a file is an image file or not

    The suffix comparison is case-insensitive, so ``image.FITS`` is accepted
    just like ``image.fits``.

    :param filename: the filename containing a possible image
    :return: True if the name ends with one of IMG_EXTENSIONS, False otherwise
    """
    # str.endswith accepts a tuple of suffixes, so one call covers every
    # registered extension (an empty tuple simply yields False).
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
def make_dataset(dir):
    """
    Helper Function to make a dataset containing all images in a certain directory

    The directory tree is walked recursively. Directories and the file names
    inside each directory are both visited in sorted order, so the returned
    list does not depend on the order in which the filesystem lists entries.

    :param dir: the directory containing the dataset
    :return: images: list of image paths
    :raises AssertionError: if ``dir`` is not an existing directory
    """
    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped under ``python -O``; exception type and message are unchanged.
    if not os.path.isdir(dir):
        raise AssertionError('%s is not a valid directory' % dir)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk makes no ordering promise for the names within a directory.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images
def default_fits_loader(file_name: str, img_size: tuple, hdu: int = 1):
    """
    Default loader: read the data of one HDU from a FITS file and resize it

    :param file_name: path of the FITS file to open
    :param img_size: output shape handed to ``skimage.transform.resize``
    :param hdu: index of the HDU whose ``data`` is read; defaults to 1,
        the index this loader has always used
    :return: (data, label): the resized array (with a trailing channel axis
        appended when it has fewer than 3 dimensions) and its label, which is
        currently the very same array
    """
    # The context manager closes the file handle even when reading or
    # resizing raises; without it the handle is never closed.
    with fits.open(file_name) as hdu_list:
        # Resize while the file is still open, since astropy may load the
        # HDU data lazily.
        _data = resize(hdu_list[hdu].data, img_size)

    # add channels
    if _data.ndim < 3:
        _data = _data.reshape((*_data.shape, 1))

    # TODO: Insert Custom Label Loader
    _label = _data
    return _data, _label
class FITSDataset(data.Dataset):
    """
    Dataset serving the image files found below a directory

    Each item is the ``(image, label)`` pair produced by
    ``default_fits_loader``; only the image is passed through ``transforms``.
    """

    def __init__(self, data_path, transforms, img_size):
        """
        :param data_path: directory that is searched for image files
        :param transforms: callable applied to every loaded image, or None
            to return the image as loaded
        :param img_size: size forwarded to ``default_fits_loader``
        """
        self.img_size = img_size
        self.transforms = transforms
        self.data_path = data_path
        # Collect the file list once up front; items are loaded on access.
        self.img_files = make_dataset(data_path)

    def __len__(self):
        """
        :return: number of image files in the dataset
        """
        return len(self.img_files)

    def __getitem__(self, index):
        """
        Load the file at position ``index``

        :param index: position in the list of image files
        :return: tuple (image, label); the image is transformed if
            ``transforms`` is set, the label is returned as loaded
        """
        image, label = default_fits_loader(self.img_files[index], self.img_size)
        if self.transforms is None:
            return (image, label)
        return (self.transforms(image), label)
if __name__ == '__main__':
    # Demo: load a directory of FITS files in batches and dump every sample
    # as an image file.
    from matplotlib import pyplot as plt

    # NOTE: both paths look like placeholders; point them at real
    # directories before running.
    save_dir = "PATH_TO_SAVE_DIR"
    dataset = FITSDataset("PATH_TO_DATA_DIR", transforms.ToTensor(), (64, 64))
    data_loader = data.DataLoader(dataset, batch_size=2)

    for batch_idx, (_data_tensor, _label_tensor) in enumerate(data_loader):
        print(batch_idx)
        # Save each sample on its own: a whole batch squeezed at once has a
        # leading batch axis that imsave cannot draw, and a fixed file name
        # would overwrite the output of every earlier batch.
        for sample_idx, _sample in enumerate(_data_tensor):
            file_name = "batch%d_sample%d.png" % (batch_idx, sample_idx)
            # assumes a sample squeezes to a 2-D array (single channel) --
            # TODO confirm for FITS data with more than one channel
            plt.imsave(os.path.join(save_dir, file_name), _sample.squeeze().numpy())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment