Skip to content

Instantly share code, notes, and snippets.

@Sunmish
Created May 11, 2022 06:08
Show Gist options
  • Save Sunmish/beab0db9cafcd99922e343fc3e2879c6 to your computer and use it in GitHub Desktop.
Save Sunmish/beab0db9cafcd99922e343fc3e2879c6 to your computer and use it in GitHub Desktop.
Trim a tiled ASKAP image.
#! /usr/bin/env python
from argparse import ArgumentParser
import numpy as np
from astropy.io import fits
from scipy import ndimage
import logging
# Log records are printed as "LEVEL (module): message"; this module's own
# logger is set to DEBUG so its debug messages are not filtered at the logger.
_LOG_FORMAT = "%(levelname)s (%(module)s): %(message)s"
logging.basicConfig(format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def minimal_bounding_box(image_data):
    """
    Obtain the minimal bounding box enclosing all finite pixels of an image.

    Based on https://stackoverflow.com/a/31402351

    Parameters
    ----------
    image_data : array_like
        Image whose first two axes are (row, column); NaN and +/-inf pixels
        are treated as blank.

    Returns
    -------
    tuple of int
        ``(row_min, row_max, col_min, col_max)``. Both maxima are
        *inclusive* indices, i.e. the box is
        ``image_data[row_min:row_max + 1, col_min:col_max + 1]``.

    Raises
    ------
    ValueError
        If the image contains no finite pixels at all.
    """
    indices = np.nonzero(np.isfinite(image_data))
    if indices[0].size == 0:
        # np.min/np.max on an empty array would fail with an opaque
        # reduction error; report the actual problem instead.
        raise ValueError("Image contains no finite pixels; cannot compute a bounding box.")
    rows, cols = indices[0], indices[1]
    return int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max())
def blank_islands(image_data):
    """
    Blank (set to NaN) every island of pixels except the largest one.

    NaN pixels are background; every other pixel (including +/-inf) belongs
    to an island. Islands are found with ``scipy.ndimage.label`` using its
    default structuring element, so pixels sharing an edge are connected.

    Parameters
    ----------
    image_data : numpy.ndarray
        Floating-point image. After ``np.squeeze`` it is expected to be 2-D,
        and its last two axes are the image axes (any extra axes must have
        length 1 and lead).

    Returns
    -------
    image_data : numpy.ndarray
        The same array object, modified in place.
    islands_blanked : bool
        True if more than one island was found (so something was blanked).
    """
    logger.debug("Retyping image...")
    # astype() copies and yields native byte order; FITS data may be big-endian.
    arr = np.squeeze(image_data).astype("float64")

    # mask islands, remove nans
    logger.debug("Masking...")
    foreground = ~np.isnan(arr)

    logger.debug("Labelling array...")
    # Labels run 1..n_islands; 0 is background. Using the returned count
    # (instead of removing 0 from the label set) also works when the image
    # has no NaN pixels and therefore no background label at all.
    source_image, n_islands = ndimage.label(foreground)
    logger.debug("Number of islands: {}".format(n_islands))

    if n_islands <= 1:
        return image_data, False

    # little islands detected!
    logger.debug("Removing extra {} island(s)...".format(n_islands - 1))
    # Pixel count per label in one pass; ignore the background label.
    sizes = np.bincount(source_image.ravel())
    sizes[0] = 0
    largest = np.argmax(sizes)
    discard = foreground & (source_image != largest)
    # A 2-D boolean mask after the Ellipsis addresses the last two axes.
    image_data[..., discard] = np.nan
    return image_data, True
def _trimmed_name(path, outname):
    """
    Build an output filename by replacing the ".fits" in *path* with ".<outname>".

    Only the last occurrence of ".fits" is replaced, so directory names that
    happen to contain ".fits" are left alone. If *path* contains no ".fits",
    ".<outname>" is appended so the output can never be the input file itself.
    """
    head, sep, tail = path.rpartition(".fits")
    if not sep:
        return "{}.{}".format(path, outname)
    return "{}.{}{}".format(head, outname, tail)


def trim(image, outname,
         overwrite=False,
         factor=1.,
         additional_images=None,
         additional_factors=None,
         exit_if_trim=False,
         template=None):
    """
    Trim images.

    The primary HDU of *image* has all but its largest pixel island blanked
    (via ``blank_islands``). It is then cropped to the minimal rectangle
    containing its remaining finite pixels. CRPIX1/CRPIX2 are shifted to
    match, the data are divided by *factor*, and the result is written
    alongside the input with ".fits" replaced by ".<outname>".

    Parameters
    ----------
    image : str
        Path to the main FITS image.
    outname : str
        Tag that replaces ".fits" in output filenames.
    overwrite : bool
        Passed to ``writeto``; allow replacing existing output files.
    factor : float
        The trimmed main image is divided by this value.
    additional_images : list of str, optional
        Further FITS images that receive the same blanking, crop and CRPIX
        shift as the main image, and are written with the same naming rule.
        NOTE(review): they are assumed to share the main image's array shape
        and pixel grid; this is not checked.
    additional_factors : list of float, optional
        Per-image divisors for *additional_images*. If None, every additional
        image is divided by 1.
    exit_if_trim : bool
        Only used with *template*: if NAXIS1/NAXIS2 already match the
        template, write an unmodified copy and return without trimming.
    template : str, optional
        FITS file whose NAXIS1/NAXIS2 are compared against *image*.

    Returns
    -------
    None
    """
    with fits.open(image) as f:
        if template is not None:
            t_header = fits.getheader(template)
            already_trimmed = (t_header["NAXIS1"] == f[0].header["NAXIS1"]
                               and t_header["NAXIS2"] == f[0].header["NAXIS2"])
            if already_trimmed and exit_if_trim:
                logger.warning("Image already at template dimensions, exiting with no trimming.")
                f.writeto(_trimmed_name(image, outname), overwrite=overwrite)
                return None

        original_crpix = f[0].header["CRPIX1"], f[0].header["CRPIX2"]

        # trim errant islands:
        logger.debug("Blanking islands...")
        f[0].data, _ = blank_islands(f[0].data)

        blank_mask = None
        if additional_images is not None:
            logger.debug("Copying template...")
            # Pixels that are blank in the main image (after island removal)
            # are blanked in the additional images too, before cropping.
            blank_mask = np.isnan(f[0].data)

        # constrict to a minimal rectangular bounding box:
        logger.debug("Trimming excess image...")
        row_min, row_max, col_min, col_max = minimal_bounding_box(np.squeeze(f[0].data))
        # The bounding-box maxima are inclusive indices, so +1 keeps the last
        # finite row/column under Python's exclusive slice end.
        rows = slice(row_min, row_max + 1)
        cols = slice(col_min, col_max + 1)
        f[0].data = f[0].data[..., rows, cols]

        # move reference pixels to match new array: dropping k leading
        # columns (rows) shifts every x (y) pixel coordinate down by k.
        crpix1 = original_crpix[0] - col_min
        crpix2 = original_crpix[1] - row_min
        f[0].header["CRPIX1"] = crpix1
        f[0].header["CRPIX2"] = crpix2

        f[0].data /= factor
        f.writeto(_trimmed_name(image, outname), overwrite=overwrite)

    if additional_images is None:
        return None

    if additional_factors is None:
        additional_factors = [1.] * len(additional_images)

    for i, additional_image in enumerate(additional_images):
        logger.debug("Trimming {} based on template".format(additional_image))
        # assume same header!
        with fits.open(additional_image) as a:
            a[0].data[blank_mask] = np.nan
            a[0].data = a[0].data[..., rows, cols]
            a[0].header["CRPIX1"] = crpix1
            a[0].header["CRPIX2"] = crpix2
            a[0].data /= additional_factors[i]
            logger.debug("saving")
            a.writeto(_trimmed_name(additional_image, outname), overwrite=overwrite)
    return None
def cli():
    """
    Command-line entry point: parse arguments and call ``trim``.

    If the number of ``--additional_dividers`` does not match the number of
    ``--additional_images``, every additional image is divided by 1 instead
    (with a warning when dividers were actually supplied).
    """
    description = "Trim ASKAP images - remove PB sidelobe islands and trim excess blank space."
    ps = ArgumentParser(description=description)
    ps.add_argument("image", help="Main image to trim.")
    ps.add_argument("-o", "--outname", default="trim", help="Appended to output filename to avoid overwrites.")
    ps.add_argument("-c", "--clobber", action="store_true", help="Overwrite anyway.")
    ps.add_argument("-t", "--template", default=None, type=str, help="Template image to check dimensions.")
    ps.add_argument("--divider", default=1., type=float, help="Divide image by this factor.")
    ps.add_argument("--additional_images", nargs="*", default=None, help="List of additional images to trim, using the prime image as reference.")
    ps.add_argument("--additional_dividers", nargs="*", type=float, default=[], help="List of additional division factors.")
    ps.add_argument("--exit_if_trim", action="store_true", help="If template is supplied, exit if template and image dimensions match (image has probably already been trimmed...)")
    args = ps.parse_args()

    if args.exit_if_trim and args.template is None:
        # trim() only consults exit_if_trim when a template is given.
        logger.warning("--exit_if_trim has no effect without --template.")

    if args.additional_images is not None:
        if len(args.additional_images) != len(args.additional_dividers):
            if args.additional_dividers:
                # The user asked for specific dividers but gave the wrong
                # number; report it visibly instead of only at DEBUG level.
                logger.warning(
                    "{} additional divider(s) given for {} additional image(s); "
                    "ignoring them.".format(len(args.additional_dividers),
                                            len(args.additional_images)))
            logger.debug("Setting additional factors to 1")
            args.additional_dividers = [1.] * len(args.additional_images)

    trim(args.image, args.outname, overwrite=args.clobber, factor=args.divider,
         additional_images=args.additional_images,
         additional_factors=args.additional_dividers,
         exit_if_trim=args.exit_if_trim,
         template=args.template
         )


if __name__ == "__main__":
    cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment