Pipelining script for producing aligned AIA datacubes.
"""
Process images from level 1 FITS to level 1.5 aligned cutouts in Zarr
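
Example invocation (script name and paths are illustrative):

    python aia_cutout_pipeline.py \
        --fits-root /path/to/level1/fits \
        --zarr-root /path/to/aligned_cutouts.zarr \
        --eis-file /path/to/eis_observation.fits \
        --deconvolve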
"""
import copy
import glob
import logging
import os
import traceback

import aiapy
import aiapy.calibrate as ac
import aiapy.psf
import astropy
import astropy.time
import astropy.units as u
import astropy.wcs
import click
import distributed
import eispac.core.eismap  # noqa: F401 -- this import registers the EISMap class
import sunpy
import sunpy.map
import zarr
from aiapy.util import sdo_location
from astropy.coordinates import SkyCoord
from reproject import reproject_interp
from sunpy.coordinates import Helioprojective, RotatedSunFrame, transform_with_sun_center


def diff_rot_align(m_input, ref_center=None, rot_model='snodgrass'):
"""
Align an AIA map to a particular frame by compensating for differential rotation
Parameters
----------
    m_input
        AIA observation to rotate
    ref_center
        Reference coordinate at the center of the aligned frame. This has to be
        explicitly specified so that every map has the same WCS. The observer of
        its frame is the observer the maps are rotated to.
    rot_model
        Differential rotation model passed to `~sunpy.coordinates.RotatedSunFrame`
"""
# Construct rotated meta frame
rot_frame = RotatedSunFrame(
base=ref_center.frame, rotated_time=m_input.date, rotation_model=rot_model)
# Construct new header
out_shape = m_input.data.shape
header = sunpy.map.make_fitswcs_header(
out_shape,
ref_center,
scale=u.Quantity(m_input.scale),
rotation_matrix=m_input.rotation_matrix,
instrument=m_input.meta['instrume'],
telescope=m_input.meta['telescop'],
wavelength=m_input.wavelength,
exposure=m_input.exposure_time,
)
# Reproject
out_wcs = astropy.wcs.WCS(header)
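    # Attaching the rotated frame to the WCS lets sunpy's WCS-to-frame
    # mapping (and thus reproject) resolve this WCS to the rotated frame
    # rather than a plain helioprojective frame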
out_wcs.coordinate_frame = rot_frame
with transform_with_sun_center():
arr, _ = reproject_interp(m_input, out_wcs, out_shape)
return sunpy.map.Map(arr, header)


def update_meta(smap, smap_original):
"""
All remaining metadata modifications should be made here
"""
    # Make sure relevant metadata from the original map that may have been
    # dropped or overwritten is preserved. We only preserve whitelisted keys
    # to avoid any potential collisions with updated keys. Add more keys
    # here as needed.
key_whitelist = [
'quality',
'bunit',
'dn_gain',
'eff_area',
'eff_ar_v',
]
for k in key_whitelist:
if k in smap_original.meta:
smap.meta[k] = smap_original.meta[k]
    # NOTE: The original observation time has to be encoded somewhere because
    # the date-obs key is overwritten with the obstime of the reference
    # coordinate frame in the alignment step. Stashing it in a separate key
    # is not ideal, but it preserves the original time.
smap.meta['t_obs'] = smap_original.date.isot
# These data are no longer level 1 or 1.5. We'll call them level 2
# because they've been derotated and deconvolved.
smap.meta['lvl_num'] = 2
return smap


def map_to_zarr(smap, zarr_store=None):
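    """
    Save the map's data and metadata to a Zarr store, with one group per
    channel and one dataset per observation time.
    """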
root = zarr.open(store=zarr_store, mode='a')
name = f'{smap.wavelength.value:.0f}/{smap.meta["t_obs"]}'
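    # Store each cutout as a single chunk in its own dataset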
ds = root.create_dataset(name, data=smap.data, chunks=smap.data.shape)
meta = copy.deepcopy(smap.meta)
problem_keys = ['history', 'comment'] # these can have dtypes that are not JSON serializable
for pk in problem_keys:
if pk in meta:
del meta[pk]
ds.attrs['meta'] = meta
return name


def cutout(smap, blc=None, trc=None):
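    """
    Crop the map to the box spanned by the bottom left and top right corners.
    """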
bl = SkyCoord(*blc, frame=smap.coordinate_frame)
tr = SkyCoord(*trc, frame=smap.coordinate_frame)
return smap.submap(bl, top_right=tr)


def cutout_pad(smap, blc=None, trc=None):
"""
This does a padded cutout around the region so that the reprojection is faster.
This takes a full cut in the Tx dimension and pads the top and bottom cut in
height by 10%.
"""
height = trc[1] - blc[1]
blc_Ty = blc[1] - 0.1 * height
trc_Ty = trc[1] + 0.1 * height
bl = SkyCoord(smap.bottom_left_coord.Tx, blc_Ty, frame=smap.coordinate_frame)
tr = SkyCoord(smap.top_right_coord.Tx, trc_Ty, frame=smap.coordinate_frame)
return smap.submap(bl, top_right=tr)


@click.command()
@click.option('--fits-root', help='Directory containing level 1 FITS files')
@click.option('--zarr-root', help='Path to Zarr store to save the resulting data products.')
@click.option('--dask-scheduler-address', default=None, help='Dask scheduler to connect to.')
@click.option('--blc', nargs=2, type=float, default=None, help='Bottom left coordinate of the cutout, in arcsec.')
@click.option('--trc', nargs=2, type=float, default=None, help='Top right coordinate of the cutout, in arcsec.')
@click.option('--reference-time', type=str, default=None, help='Time to derotate the images to, as an ISOT string.')
@click.option('--eis-file', type=str, default=None,
              help='Path to an EIS file. If given, the reference time and cutout coordinates are derived from it.')
@click.option('--num-files', type=int, default=None, help='Debug option for only processing a few files.')
@click.option('--channels', type=str, default=None,
              help='Comma-separated list of channels to process. If not given, the 94, 131, 171, 193, 211, and 335 channels are processed.')
@click.option('--deconvolve', is_flag=True, default=False, help='If set, apply PSF deconvolution.')
@click.option('--correction-table', type=str, default=None, help='Path to correction table file.')
@click.option('--dry-run', is_flag=True, default=False, help='If set, the resulting files will not be saved.')
@click.option('--log-file', type=str, default=None, help='Path to the log file. If not given, log to stderr.')
@click.option('--log-level', type=str, default='INFO', help='Logging level, e.g. DEBUG, INFO, WARNING.')
def cli(fits_root,
zarr_root,
dask_scheduler_address,
blc,
trc,
reference_time,
eis_file,
num_files,
channels,
deconvolve,
correction_table,
dry_run,
log_file,
log_level,
):
"""
Pipeline for processing level 1 AIA images into level PSF-deconvolved, level 1.5,
exposure time normalized, derotated, and cropped images in Zarr format.
"""
    # TODO: add an intermediate crop step so that we aren't differentially
    # rotating full-disk images. This would differentially rotate the reference
    # FOV to the unrotated frame, expand it by some padding in latitude and
    # longitude, and then crop the map to that FOV. The padding needs to be
    # larger in longitude than in latitude.
# Log configuration
logging.basicConfig(filename=log_file, level=log_level)
logging.info('AIA aligned cutout pipeline')
logging.info(f'sunpy v{sunpy.__version__}')
logging.info(f'astropy v{astropy.__version__}')
logging.info(f'aiapy v{aiapy.__version__}')
logging.info(f'dask-distributed v{distributed.__version__}')
# Get all of the FITS files, sorting by wavelength and time
logging.info(f'Processing level 1 FITS files in {fits_root}')
fits_format = 'aia.lev1_euv_12s.*.{channel}.image_lev1.fits'
if channels is None:
channels = ['94','131','171','193','211','335']
else:
channels = channels.split(',')
logging.info(f'Processing files from {channels} channels')
fits_files = {c: sorted(glob.glob(os.path.join(fits_root, fits_format.format(channel=c)))) for c in channels}
# Get corner coordinates and reference time
if eis_file is not None:
logging.debug(f'Using EIS file {eis_file} to determine reference time and bounding box')
m_eis = sunpy.map.Map(eis_file)
t_ref = m_eis.date_average
# Make the AIA cutout as wide and tall as the largest EIS dimension
width = m_eis.top_right_coord.Tx - m_eis.bottom_left_coord.Tx
height = m_eis.top_right_coord.Ty - m_eis.bottom_left_coord.Ty
fov = max(width, height)
blc = u.Quantity([m_eis.center.Tx, m_eis.center.Ty]) - fov/2
trc = blc + fov
elif blc is not None and trc is not None and reference_time is not None:
t_ref = astropy.time.Time(reference_time, format='isot', scale='utc')
# NOTE: These are in the HPC frame defined by SDO at reference_time
blc = blc * u.arcsec
trc = trc * u.arcsec
else:
raise ValueError('Must specify either an EIS map or reference time and cutout coordinates')
logging.info(f'Cutout bounding box blc: {blc}, trc: {trc}')
logging.info(f'Reference time {t_ref}')
# Retrieve pointing and correction tables for full time range
pointing_table = ac.util.get_pointing_table(t_ref-1*u.d, t_ref+1*u.d) # NOTE: requires remote connection
correction_table = ac.util.get_correction_table(correction_table=correction_table) # NOTE: requires remote connection
# Define reference observer to rotate to and construct reference coordinate at the center of the map
ref_observer = sdo_location(t_ref) # NOTE: requires remote connection
hpc_frame = Helioprojective(observer=ref_observer, obstime=ref_observer.obstime)
ref_center = SkyCoord(*(blc + trc)/2, frame=hpc_frame)
logging.info(f'Reference center coordinate: {ref_center}')
# Attach to scheduler
logging.info('Connecting to dask cluster')
client = distributed.Client(address=dask_scheduler_address)
logging.info(client)
logging.info(client.dashboard_link)
# Scatter some of our kwargs to be nicer to the scheduler
ptable_scatter = client.scatter(pointing_table)
ctable_scatter = client.scatter(correction_table)
# Compute PSFs
if deconvolve:
logging.info('Deconvolving with PSF')
psfs = {c: aiapy.psf.psf(int(c)*u.angstrom, use_gpu=True) for c in channels}
psfs_scattered = {k: client.scatter(v) for k,v in psfs.items()}
# Run pipeline
for c in channels:
futures = []
        # Use a local count so that the first channel's file count does not
        # truncate the file lists of subsequent channels
        n_files = len(fits_files[c]) if num_files is None else num_files
        channel_files = fits_files[c][:n_files]
logging.debug(f'Processing {len(channel_files)} files')
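        # Each file passes through a chain of lazy tasks:
        # read -> (deconvolve) -> update pointing -> register (level 1.5)
        # -> padded cutout -> correct degradation -> normalize by exposure time
        # -> derotate to the reference frame -> final cutout -> update metadata -> save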
for f in channel_files:
logging.debug(f'Processing {f}')
m_l1 = client.submit(sunpy.map.Map, f)
if deconvolve:
# Specifying resources here to avoid CUDA memory errors when running on the GPU
m_l1_d = client.submit(aiapy.psf.deconvolve, m_l1, use_gpu=True, psf=psfs_scattered[c],
resources={'GPU':1})
else:
m_l1_d = m_l1
m_pt = client.submit(ac.update_pointing, m_l1_d, pointing_table=ptable_scatter)
m_reg = client.submit(ac.register, m_pt)
m_cutout_pad = client.submit(cutout_pad, m_reg, blc=blc, trc=trc)
m_deg = client.submit(ac.correct_degradation, m_cutout_pad, correction_table=ctable_scatter)
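            # Normalize by the exposure time (DN -> DN / s)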
m_norm = client.submit(lambda x: x / x.exposure_time, m_deg)
m_align = client.submit(diff_rot_align, m_norm, ref_center=ref_center)
m_cutout = client.submit(cutout, m_align, blc=blc, trc=trc)
m_meta_update = client.submit(update_meta, m_cutout, m_norm)
# If dry_run, skip the save step
if dry_run:
futures.append(m_meta_update)
else:
m_save = client.submit(map_to_zarr, m_meta_update, zarr_store=zarr_root)
futures.append(m_save)
logging.info(f'Processing {c} files...')
distributed.wait(futures)
logging.debug(futures)
        for i, f in enumerate(futures):
            fexc = f.exception()
            if fexc is not None:
                # logging.exception is only valid inside an except block;
                # log the future's exception and traceback directly instead
                logging.error(f'Error in file {channel_files[i]}')
                logging.error(fexc)
                logging.error(traceback.format_tb(f.traceback()))
if not dry_run:
n_L2 = len(zarr.open(store=zarr_root, mode='r')[c])
n_L1 = len(channel_files)
if n_L2 != n_L1:
                logging.warning(f'Number of L2 files {n_L2} not equal to number of L1 files {n_L1}')
logging.debug(f'Number of L1 files for channel {c}: {n_L1}')
logging.debug(f'Number of L2 files for channel {c}: {n_L2}')


if __name__ == '__main__':
cli()
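
# Reading the products back out of the store (a sketch; the path and time
# key are illustrative):
#
#     import sunpy.map
#     import zarr
#     root = zarr.open('/path/to/aligned_cutouts.zarr', mode='r')
#     ds = root['171']['2021-03-01T12:00:00.000']  # <channel>/<t_obs>
#     m = sunpy.map.Map(ds[:], dict(ds.attrs['meta']))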