Skip to content

Instantly share code, notes, and snippets.

@Miladiouss
Last active January 16, 2021 21:08
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Miladiouss/30c6a90da3243eafe19fb02acf0747b1 to your computer and use it in GitHub Desktop.
Save Miladiouss/30c6a90da3243eafe19fb02acf0747b1 to your computer and use it in GitHub Desktop.
Easy astronomical FITS file handling for Python. It includes convenient cutout and save functions. Additionally, it provides a percentile normalization method that is well suited to scaling FITS images for better visualization (similar to DS9's MinMax scaling).
import numpy as np
from pathlib import Path
# AstroPy
import astropy
from astropy.coordinates import SkyCoord, GCRS, ICRS, GeocentricTrueEcliptic, Galactic
import astropy.units as u
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D
def sum2mag(pix_sum, zero_point=27.00):
    """Convert summed pixel values to a magnitude.

    Computes ``-2.5 * log10(pix_sum) + zero_point``.

    pix_sum: scalar or NumPy array of summed pixel values. Non-positive
        values yield ``-inf``/``nan`` (NumPy's ``log10`` behaviour).
    zero_point: photometric zero point added to the instrumental magnitude.

    Returns a scalar or array matching ``pix_sum``.
    """
    instrumental_mag = -2.5 * np.log10(pix_sum)
    return instrumental_mag + zero_point
class FITS:
    """Convenience wrapper around one image HDU of a FITS file.

    The file is opened with ``astropy.io.fits.open`` when the object is
    created and stays open until :meth:`close` is called. The object can also
    be used as a context manager (``with FITS(path) as f: ...``), which closes
    the file on exit.
    """

    def __init__(self, fits_path, band=None, ext_idx=1):
        """
        fits_path: path of the FITS file to open.
        band: free-form label stored on the instance (not used by this class).
        ext_idx: index of the HDU whose pixel data and WCS are used.

        Raises ValueError if the selected HDU contains no data.
        """
        self.path = fits_path
        self.band = band
        self.file = fits.open(fits_path)
        self.ext_idx = ext_idx
        try:
            hdu = self.file[self.ext_idx]
            self.data = hdu.data  # pixel array of the selected HDU
            if self.data is None:
                raise ValueError(
                    'HDU {} of "{}" contains no data.'.format(ext_idx, fits_path))
            self.wcs = WCS(hdu.header)  # world coordinate system of that HDU
        except Exception:
            # Do not leak the open file handle when initialisation fails.
            self.file.close()
            raise
        self.dim = self.data.shape  # shape of the full pixel array

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def cutout(self, pos=(50, 50), size=(64, 64), save_to_path=None,
               scale_func=None, dtype=None, overwrite=True):
        """Crop, rescale and optionally save a region of the image.

        pos: cutout centre, either in pixels (e.g. ``(50, 50)``) or as a sky
            position (e.g. ``SkyCoord(+1.2, -3.4, unit="deg", frame="fk5")``).
        size: cutout shape passed to ``Cutout2D``. A falsy value (e.g. ``None``)
            skips cropping and uses the whole image.
        save_to_path: optional output path (``str`` or ``Path``).
            ``.fits``/``.fit``/``.fts`` writes a copy of the whole FITS file with
            HDU ``ext_idx`` replaced by the cutout and its WCS.
            ``.jpeg``/``.jpg``/``.png`` writes an image via ``imsave``.
            The suffix is matched case-insensitively.
        scale_func: optional callable applied to the pixel array before the
            dtype conversion (e.g. ``percentile_normalization``).
        dtype: output dtype such as ``'float32'`` or ``'uint8'``. Defaults to
            the dtype of the original data.
        overwrite: forwarded to ``writeto`` when saving a FITS file.

        Returns the processed pixel array (always a new array).

        Raises ValueError if ``save_to_path`` has an unsupported extension.
        """
        if size:
            cut = Cutout2D(data=self.data, position=pos, size=size, wcs=self.wcs)
            output = cut.data
            output_wcs = cut.wcs  # WCS of the cropped region
        else:
            output = self.data
            output_wcs = self.wcs

        if scale_func is not None:
            output = scale_func(output)

        if dtype is None:
            dtype = self.data.dtype
        output = output.astype(np.dtype(dtype))

        if save_to_path:
            save_to_path = Path(save_to_path)
            suffix = save_to_path.suffix.lower()
            # Tuples (not bare strings) so membership is an exact match.
            if suffix in ('.fits', '.fit', '.fts'):
                self._write_fits(output, output_wcs, save_to_path, overwrite)
            elif suffix in ('.jpeg', '.jpg', '.png'):
                # NOTE(review): `imsave` is not among this file's imports; it is
                # presumably meant to be e.g. matplotlib.pyplot.imsave - TODO
                # import it, otherwise this branch raises NameError.
                imsave(save_to_path, output)
            else:
                raise ValueError(
                    'The file extension must be one of the following: .fits, .fit, .fts, '
                    '.jpeg, .jpg, .png. "{}" was given instead.'.format(save_to_path.suffix))
        return output

    def _write_fits(self, data, wcs, path, overwrite):
        """Write all HDUs of ``self.file`` to ``path``, replacing HDU ``ext_idx``
        with ``data`` and ``wcs``, without modifying the opened file."""
        source_hdu = self.file[self.ext_idx]
        header = source_hdu.header.copy()
        header.update(wcs.to_header())  # coordinates of the (possibly cropped) data
        hdus = list(self.file)
        # A fresh HDU of the same class keeps self.file (and self.data) intact.
        # NOTE(review): for a compressed HDU the new one uses astropy's default
        # compression settings rather than the source's - confirm if relevant.
        hdus[self.ext_idx] = type(source_hdu)(data=data, header=header)
        fits.HDUList(hdus).writeto(path, overwrite=overwrite)

    def close(self):
        """Close the underlying FITS file."""
        self.file.close()

    def reduce2uint8(self, output_path, p_low, p_high, overwrite=True):
        """Convert the image to a PNG-like 8-bit FITS file and save it.

        The full image is passed to ``percentile_normalization`` with ``p_low``
        and ``p_high`` as clip bounds and a scale of 255, then truncated to
        uint8. The output is a single-HDU FITS file whose header is the header
        of HDU ``ext_idx`` (re-read from ``self.path``) updated with ``self.wcs``.
        The clip values and scale are stored as the P_LOW, P_HIGH and SCL_COEF
        keywords.
        """
        scale_coef = 255
        output_data = percentile_normalization(
            self.data, p_low_feed=p_low, p_high_feed=p_high, scale_coef=scale_coef)
        # astype truncates toward zero; NaN pixels have no meaningful uint8 value.
        output_data = output_data.astype(np.dtype('uint8'))

        header = fits.getheader(self.path, self.ext_idx)
        header.update(self.wcs.to_header())
        # Item assignment creates FITS cards; setting Python attributes on the
        # Header (header.p_low = ...) would never be written to the file.
        header['P_LOW'] = (p_low, 'Lower clip value before uint8 scaling')
        header['P_HIGH'] = (p_high, 'Upper clip value before uint8 scaling')
        header['SCL_COEF'] = (scale_coef, 'Scale applied after normalization')

        hdu = fits.PrimaryHDU(data=output_data, header=header)
        fits.HDUList([hdu]).writeto(output_path, overwrite=overwrite)
def percentile_normalization(data, percentile_low=1.5, percentile_high=1.5,
                             p_low_feed=None, p_high_feed=None, scale_coef=1):
    """Clip ``data`` to a percentile range and rescale it to ``[0, scale_coef]``.

    data: array-like of pixel values. NaN values are ignored when computing
        the percentiles and the min/max, and stay NaN in the output.
    percentile_low: the lower clip bound is the ``percentile_low``-th percentile.
    percentile_high: the upper clip bound is the ``(100 - percentile_high)``-th
        percentile.
    p_low_feed, p_high_feed: explicit lower/upper clip values. Each one
        overrides the corresponding percentile whenever it is not None, so
        ``0`` is a valid override.
    scale_coef: value that the maximum of the clipped, shifted data is mapped
        to (e.g. 255 for 8-bit output).

    Returns a new floating-point array; ``data`` is not modified. If all
    clipped values are equal, the non-NaN output is all zeros (instead of NaN
    from 0/0).
    """
    data = np.asarray(data)
    # Percentiles are only computed when no explicit bound is supplied;
    # they are expensive on large images.
    p_low = p_low_feed if p_low_feed is not None else np.nanpercentile(data, percentile_low)
    p_high = p_high_feed if p_high_feed is not None else np.nanpercentile(data, 100 - percentile_high)

    # Bound values between p_low and p_high.
    normalized = np.clip(data, p_low, p_high)
    # Integer input with integer bounds stays integer; the in-place division
    # below needs a floating dtype.
    if not np.issubdtype(normalized.dtype, np.floating):
        normalized = normalized.astype(np.float64)

    # Shift the minimum to zero to prevent negative values.
    normalized = normalized - np.nanmin(normalized)
    # Normalize so the max is 1. Skip when the range is degenerate: after the
    # shift every non-NaN value is already 0.
    peak = np.nanmax(normalized)
    if peak > 0:
        normalized /= peak
    normalized *= scale_coef
    return normalized
# ================================= Example 1 =================================
# Read a file and visualize a cutout
# x = FITS('path/to/file.fits')
# sf = lambda data: percentile_normalization(data, percentile_high=1., percentile_low=30)
# d = x.cutout(pos=(92, 110), scale_func=sf)
# ================================= Example 2 =================================
# # Example of reducing an HSC-PDR2 float32 FITS file to uint8
# from pathlib import Path
# from FITS_Handler import FITS, percentile_normalization
# # Define low and high percentile values for each filter
# i_low = -0.481
# i_high = 1.900
# r_low = -0.280
# r_high = 1.613
# g_low = -0.175
# g_high = 0.761
# # Define input and output paths
# inPath = Path('/HSC-Drive/HSC-PDR2/tracts/8279/calexp-HSC-R-8279-5,4.fits')
# outPath = Path('uint8_' + inPath.name)
# # Read, reduce, and close
# r = FITS(inPath)
# r.reduce2uint8(outPath, p_low=r_low, p_high=r_high)
# r.close()
# # Print sizes
# print("Original Size: {:3.0f} MB".format(inPath.stat().st_size / 1e6))
# print("Output Size: {:3.0f} MB".format(outPath.stat().st_size / 1e6))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment