Last active
November 22, 2021 18:59
-
-
Save wtbarnes/ca4417a55a695f1d604519ece0f538bc to your computer and use it in GitHub Desktop.
Pipelining script for producing aligned AIA datacubes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Process images from level 1 FITS to level 1.5 aligned cutouts in Zarr | |
""" | |
import copy | |
import glob | |
import logging | |
import os | |
import traceback | |
import sunpy | |
import aiapy | |
import astropy | |
import aiapy.calibrate as ac | |
import aiapy.psf | |
from aiapy.util import sdo_location | |
from astropy.coordinates import SkyCoord | |
import astropy.time | |
import astropy.units as u | |
import astropy.wcs | |
import click | |
import eispac.core.eismap # this registers the EISMap class | |
import distributed | |
from reproject import reproject_interp | |
import sunpy.map | |
from sunpy.coordinates import Helioprojective, RotatedSunFrame, transform_with_sun_center | |
import zarr | |
def diff_rot_align(m_input, ref_center=None, rot_model='snodgrass'):
    """
    Align an AIA map to a particular frame by compensating for differential rotation

    Parameters
    ----------
    m_input
        AIA observation to rotate
    ref_center
        Center reference coordinates. These have to be explicitly specified so that
        every map has the same WCS. The signature default of ``None`` exists only
        so the argument can be passed by keyword; a coordinate is required.
    rot_model
        Name of the differential rotation model, forwarded to ``RotatedSunFrame``
        as ``rotation_model``.

    Returns
    -------
    A new map built from the reprojected array and the newly constructed header.
    The output array has the same shape as ``m_input.data``.

    Raises
    ------
    ValueError
        If ``ref_center`` is not provided.
    """
    # Fail early with a clear message instead of an AttributeError on ``None.frame``
    if ref_center is None:
        raise ValueError('ref_center must be specified to align the map')
    # Construct rotated meta frame
    rot_frame = RotatedSunFrame(
        base=ref_center.frame, rotated_time=m_input.date, rotation_model=rot_model)
    # Construct new header: same shape, scale and rotation as the input map,
    # but referenced to ref_center
    out_shape = m_input.data.shape
    header = sunpy.map.make_fitswcs_header(
        out_shape,
        ref_center,
        scale=u.Quantity(m_input.scale),
        rotation_matrix=m_input.rotation_matrix,
        instrument=m_input.meta['instrume'],
        telescope=m_input.meta['telescop'],
        wavelength=m_input.wavelength,
        exposure=m_input.exposure_time,
    )
    # Reproject onto a WCS built from the new header whose coordinate frame is
    # replaced by the rotated frame
    out_wcs = astropy.wcs.WCS(header)
    out_wcs.coordinate_frame = rot_frame
    with transform_with_sun_center():
        arr, _ = reproject_interp(m_input, out_wcs, out_shape)
    return sunpy.map.Map(arr, header)
def update_meta(smap, smap_original):
    """
    Apply the final metadata adjustments to a processed map.

    Selected keys are copied from ``smap_original.meta`` into ``smap.meta``,
    the original observation time is recorded under ``t_obs`` and the
    processing level is set to 2. ``smap`` is modified in place and returned.
    """
    # Only an explicit allow-list of keys is carried over from the original
    # map so that keys updated by earlier steps are not clobbered. Extend
    # this tuple as needed.
    preserved_keys = ('quality', 'bunit', 'dn_gain', 'eff_area', 'eff_ar_v')
    source_meta = smap_original.meta
    for key in preserved_keys:
        if key not in source_meta:
            continue
        smap.meta[key] = source_meta[key]
    # NOTE: The alignment step overwrites the date-obs key with the obstime of
    # the coordinate frame, so the true observation time is stashed in a
    # separate key. This is not ideal.
    smap.meta['t_obs'] = smap_original.date.isot
    # Derotated (and possibly deconvolved) data are neither level 1 nor 1.5;
    # label them level 2.
    smap.meta['lvl_num'] = 2
    return smap
def map_to_zarr(smap, zarr_store=None):
    """
    Store a map's data and metadata as a dataset in a Zarr store.

    The dataset is created at ``<wavelength>/<t_obs>`` inside the store opened
    in append mode, with a single chunk covering the whole array. The map
    metadata, minus the ``history`` and ``comment`` keys, is attached to the
    dataset attributes under ``meta``. Returns the dataset name.
    """
    group = zarr.open(store=zarr_store, mode='a')
    dataset_name = f'{smap.wavelength.value:.0f}/{smap.meta["t_obs"]}'
    dataset = group.create_dataset(dataset_name, data=smap.data, chunks=smap.data.shape)
    # Work on a copy so the map's own metadata is left untouched
    header = copy.deepcopy(smap.meta)
    # These entries can have dtypes that are not JSON serializable
    for unserializable_key in ('history', 'comment'):
        header.pop(unserializable_key, None)
    dataset.attrs['meta'] = header
    return dataset_name
def cutout(smap, blc=None, trc=None):
    """
    Crop a map to the rectangle spanned by two corners.

    ``blc`` and ``trc`` are the bottom left and top right corner components;
    each is unpacked into a ``SkyCoord`` in the map's own coordinate frame.
    """
    map_frame = smap.coordinate_frame
    bottom_left, top_right = (
        SkyCoord(*corner, frame=map_frame) for corner in (blc, trc)
    )
    return smap.submap(bottom_left, top_right=top_right)
def cutout_pad(smap, blc=None, trc=None):
    """
    Crop a map to a padded band around the region of interest so that the
    later reprojection has less data to process.

    The full extent of the map is kept in the Tx dimension. In Ty, the band
    between ``blc[1]`` and ``trc[1]`` is extended above and below by 10% of
    its height.
    """
    padding = 0.1 * (trc[1] - blc[1])
    map_frame = smap.coordinate_frame
    lower_left = SkyCoord(smap.bottom_left_coord.Tx, blc[1] - padding, frame=map_frame)
    upper_right = SkyCoord(smap.top_right_coord.Tx, trc[1] + padding, frame=map_frame)
    return smap.submap(lower_left, top_right=upper_right)
@click.command()
@click.option('--fits-root', help='Directory containing level 1 FITS files')
@click.option('--zarr-root', help='Path to Zarr store to save the resulting data products.')
@click.option('--dask-scheduler-address', default=None, help='Dask scheduler to connect to.')
@click.option('--blc', nargs=2, type=float, default=None, help='Bottom left coordinate of the cutout.')
@click.option('--trc', nargs=2, type=float, default=None, help='Top right coordinate of the cutout.')
@click.option('--reference-time', type=str, default=None, help='Time to derotate the image to.')
@click.option('--eis-file', type=str, default=None,
              help='Path to EIS file. The reference time and cutout coordinates can be derived from this.')
@click.option('--num-files', type=int, default=None, help='Debug option for only processing a few files.')
@click.option('--channels', type=str, default=None,
              help='Which channels to process. If None, all 6 EUV channels will be processed.')
@click.option('--deconvolve', is_flag=True, default=False, help='If set, apply PSF deconvolution.')
@click.option('--correction-table', type=str, default=None, help='Path to correction table file.')
@click.option('--dry-run', is_flag=True, default=False, help='If set, the resulting files will not be saved.')
@click.option('--log-file', type=str, default=None,)
@click.option('--log-level', type=str, default='INFO',)
def cli(fits_root,
        zarr_root,
        dask_scheduler_address,
        blc,
        trc,
        reference_time,
        eis_file,
        num_files,
        channels,
        deconvolve,
        correction_table,
        dry_run,
        log_file,
        log_level,
        ):
    """
    Pipeline for processing level 1 AIA images into PSF-deconvolved, level 1.5,
    exposure time normalized, derotated, and cropped images in Zarr format.

    The reference time and cutout bounding box come either from an EIS file
    or from the explicit --blc, --trc and --reference-time options; a
    ValueError is raised if neither is fully specified. Each channel is
    processed as a chain of dask futures, one chain per FITS file, and
    per-file failures are logged rather than raised.
    """
    # TODO: add intermediate crop step so that we aren't differentially rotating
    # full-disk images. This would differentially rotate the reference FOV to the
    # unrotated frame, expand them through some padding in lat and lon, and then
    # crop the map to that FOV. It will be more expanded in lon than lat
    # Log configuration
    logging.basicConfig(filename=log_file, level=log_level)
    logging.info('AIA aligned cutout pipeline')
    logging.info(f'sunpy v{sunpy.__version__}')
    logging.info(f'astropy v{astropy.__version__}')
    logging.info(f'aiapy v{aiapy.__version__}')
    logging.info(f'dask-distributed v{distributed.__version__}')
    # Get all of the FITS files, sorting by wavelength and time
    logging.info(f'Processing level 1 FITS files in {fits_root}')
    fits_format = 'aia.lev1_euv_12s.*.{channel}.image_lev1.fits'
    if channels is None:
        channels = ['94', '131', '171', '193', '211', '335']
    else:
        channels = channels.split(',')
    logging.info(f'Processing files from {channels} channels')
    fits_files = {c: sorted(glob.glob(os.path.join(fits_root, fits_format.format(channel=c)))) for c in channels}
    # Get corner coordinates and reference time
    if eis_file is not None:
        logging.debug(f'Using EIS file {eis_file} to determine reference time and bounding box')
        m_eis = sunpy.map.Map(eis_file)
        t_ref = m_eis.date_average
        # Make the AIA cutout as wide and tall as the largest EIS dimension
        width = m_eis.top_right_coord.Tx - m_eis.bottom_left_coord.Tx
        height = m_eis.top_right_coord.Ty - m_eis.bottom_left_coord.Ty
        fov = max(width, height)
        blc = u.Quantity([m_eis.center.Tx, m_eis.center.Ty]) - fov/2
        trc = blc + fov
    elif blc is not None and trc is not None and reference_time is not None:
        t_ref = astropy.time.Time(reference_time, format='isot', scale='utc')
        # NOTE: These are in the HPC frame defined by SDO at reference_time
        blc = blc * u.arcsec
        trc = trc * u.arcsec
    else:
        raise ValueError('Must specify either an EIS map or reference time and cutout coordinates')
    logging.info(f'Cutout bounding box blc: {blc}, trc: {trc}')
    logging.info(f'Reference time {t_ref}')
    # Retrieve pointing and correction tables for full time range
    pointing_table = ac.util.get_pointing_table(t_ref-1*u.d, t_ref+1*u.d)  # NOTE: requires remote connection
    correction_table = ac.util.get_correction_table(correction_table=correction_table)  # NOTE: requires remote connection
    # Define reference observer to rotate to and construct reference coordinate at the center of the map
    ref_observer = sdo_location(t_ref)  # NOTE: requires remote connection
    hpc_frame = Helioprojective(observer=ref_observer, obstime=ref_observer.obstime)
    ref_center = SkyCoord(*(blc + trc)/2, frame=hpc_frame)
    logging.info(f'Reference center coordinate: {ref_center}')
    # Attach to scheduler
    logging.info('Connecting to dask cluster')
    client = distributed.Client(address=dask_scheduler_address)
    logging.info(client)
    logging.info(client.dashboard_link)
    # Scatter some of our kwargs to be nicer to the scheduler
    ptable_scatter = client.scatter(pointing_table)
    ctable_scatter = client.scatter(correction_table)
    # Compute PSFs
    if deconvolve:
        logging.info('Deconvolving with PSF')
        psfs = {c: aiapy.psf.psf(int(c)*u.angstrom, use_gpu=True) for c in channels}
        psfs_scattered = {k: client.scatter(v) for k, v in psfs.items()}
    # Run pipeline
    for c in channels:
        futures = []
        # Slicing with ``None`` keeps every file. ``num_files`` is deliberately
        # NOT reassigned here: overwriting it with the first channel's file
        # count would silently cap every subsequent channel at that count.
        channel_files = fits_files[c][:num_files]
        logging.debug(f'Processing {len(channel_files)} files')
        for f in channel_files:
            logging.debug(f'Processing {f}')
            m_l1 = client.submit(sunpy.map.Map, f)
            if deconvolve:
                # Specifying resources here to avoid CUDA memory errors when running on the GPU
                m_l1_d = client.submit(aiapy.psf.deconvolve, m_l1, use_gpu=True, psf=psfs_scattered[c],
                                       resources={'GPU': 1})
            else:
                m_l1_d = m_l1
            m_pt = client.submit(ac.update_pointing, m_l1_d, pointing_table=ptable_scatter)
            m_reg = client.submit(ac.register, m_pt)
            m_cutout_pad = client.submit(cutout_pad, m_reg, blc=blc, trc=trc)
            m_deg = client.submit(ac.correct_degradation, m_cutout_pad, correction_table=ctable_scatter)
            m_norm = client.submit(lambda x: x / x.exposure_time, m_deg)
            m_align = client.submit(diff_rot_align, m_norm, ref_center=ref_center)
            m_cutout = client.submit(cutout, m_align, blc=blc, trc=trc)
            # m_norm is passed as the "original" map so its observation time
            # and selected metadata keys are carried over
            m_meta_update = client.submit(update_meta, m_cutout, m_norm)
            # If dry_run, skip the save step
            if dry_run:
                futures.append(m_meta_update)
            else:
                m_save = client.submit(map_to_zarr, m_meta_update, zarr_store=zarr_root)
                futures.append(m_save)
        logging.info(f'Processing {c} files...')
        distributed.wait(futures)
        logging.debug(futures)
        # One future was appended per input file, in order, so the two
        # sequences can be walked in lockstep
        for source_file, future in zip(channel_files, futures):
            fexc = future.exception()
            if fexc is None:
                continue
            # ``logging.exception`` is only meaningful inside an ``except``
            # block; here there is no active exception, so the remote
            # traceback is formatted explicitly and logged at ERROR level.
            tb_text = ''.join(traceback.format_exception(type(fexc), fexc, future.traceback()))
            logging.error(f'Error in file {source_file}\n{tb_text}')
        if not dry_run:
            n_L2 = len(zarr.open(store=zarr_root, mode='r')[c])
            n_L1 = len(channel_files)
            if n_L2 != n_L1:
                logging.warning(f'Number of L2 files {n_L2} not equal to number of L1 files {n_L1}')
            logging.debug(f'Number of L1 files for channel {c}: {n_L1}')
            logging.debug(f'Number of L2 files for channel {c}: {n_L2}')
if __name__ == '__main__':
    # Script entry point: hand control to the click command, which parses
    # the command-line options itself.
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment