Skip to content

Instantly share code, notes, and snippets.

@wtbarnes
Created June 10, 2021 13:30
Show Gist options
  • Save wtbarnes/8c1e8e8e39414784fa24cca3e697dfff to your computer and use it in GitHub Desktop.
Save wtbarnes/8c1e8e8e39414784fa24cca3e697dfff to your computer and use it in GitHub Desktop.
Pipeline for creating aligned, level 1.5 cutouts from level 1 AIA images
"""
Process images from level 1 FITS to level 1.5 aligned cutouts in Zarr
"""
import copy
import glob
import os
import aiapy.calibrate as ac
import astropy.units as u
from astropy.coordinates import SkyCoord
import distributed
import numpy as np
from scipy.ndimage import shift
import sunpy.map
from sunpy.physics.differential_rotation import solar_rotate_coordinate
from sunpy.map.header_helper import get_observer_meta
import zarr
# Address of the running Dask scheduler -- edit to match your cluster.
SCHEDULER_ADDRESS = 'tcp://127.0.0.1:34117'
# Destination of the aligned level 1.5 cutout Zarr store -- edit as needed.
ZARR_L15_ROOT = '/path/to/L15_cutout_aligned.zarr'
# Directory containing the level 1 FITS files -- edit as needed.
FITS_ROOT = '/path/to/L1_fits/'
# Position, in the sorted list of 171 files, of the map every image is aligned to.
REF_MAP_INDEX = 210
# Bottom-left and top-right cutout corners as (Tx, Ty) pairs, expressed in the
# helioprojective frame of the reference map selected above.
CUTOUT_BLC = u.Quantity([-425, -200], u.arcsec)
CUTOUT_TRC = u.Quantity([100, 250], u.arcsec)
# Define a few functions we will need for prepping
def shift_align(smap, ref_map=None, rot_type='snodgrass'):
    """
    Align ``smap`` to ``ref_map`` with a single rigid shift of the data array.

    The center of ``smap`` is differentially rotated to the observer of
    ``ref_map``. The data are then shifted by the offset between that rotated
    coordinate and the center of ``ref_map``, and the metadata are updated to
    describe the reference observer.

    NOTE: Alternatively, we could use
    `~sunpy.physics.differential_rotation.differential_rotate`, but it is
    considerably slower.

    Parameters
    ----------
    smap : `~sunpy.map.GenericMap`
        Map to align.
    ref_map : `~sunpy.map.GenericMap`
        Map to align to. Required; it is only a keyword argument so that this
        function plays nicely with ``client.map``.
    rot_type : `str`, optional
        Differential rotation model passed to ``solar_rotate_coordinate``.

    Returns
    -------
    `~sunpy.map.GenericMap`
        New map holding the shifted data. ``date-obs`` and the observer keys
        come from ``ref_map``; the original time of ``smap`` is kept in
        ``t_obs``.

    Raises
    ------
    ValueError
        If ``ref_map`` is not given.
    """
    if ref_map is None:
        raise ValueError('Must provide a reference map.')
    # Position of smap's center as seen by the reference observer
    new_coord = solar_rotate_coordinate(
        smap.center, observer=ref_map.observer_coordinate, rot_type=rot_type)
    # Offset between the rotated center and the reference center, in pixels.
    # Convert explicitly rather than taking `.value`. This stays correct even if
    # the coordinates and the plate scale use different angular units, and it
    # raises (instead of silently producing a bad shift) if they are incompatible.
    x_shift = ((new_coord.Tx - ref_map.center.Tx) / smap.scale.axis1).to_value(u.pix)
    y_shift = ((new_coord.Ty - ref_map.center.Ty) / smap.scale.axis2).to_value(u.pix)
    # TODO: implement in Dask
    # scipy.ndimage.shift takes one shift per array axis; map data are indexed (y, x)
    data_shifted = shift(smap.data, [y_shift, x_shift])
    # Update metadata
    new_meta = copy.deepcopy(smap.meta)
    # Drop a (truthy) 'date_obs' entry, presumably so it cannot conflict with
    # the 'date-obs' value set just below
    if new_meta.get('date_obs', False):
        del new_meta['date_obs']
    new_meta['date-obs'] = ref_map.observer_coordinate.obstime.isot
    # Preserve the old time
    new_meta['t_obs'] = smap.date.isot
    new_meta.update(get_observer_meta(ref_map.observer_coordinate,
                                      new_meta['rsun_ref'] * u.m))
    return smap._new_instance(data_shifted, new_meta)
def map_to_zarr(smap):
    """
    Store a map's data and header in the Zarr store at ``ZARR_L15_ROOT``.

    The store is opened in append mode. The data are written as a single-chunk
    array at the key ``"<wavelength>/<t_obs>"``, and a copy of the header is
    attached as the array's ``meta`` attribute. The ``history`` and ``comment``
    keys are removed from that copy first.

    Parameters
    ----------
    smap : `~sunpy.map.GenericMap`
        Map to write; its header must contain ``t_obs``.

    Returns
    -------
    `str`
        Key of the array that was created.
    """
    group = zarr.open(store=ZARR_L15_ROOT, mode='a')
    key = f'{smap.wavelength.value:.0f}/{smap.meta["t_obs"]}'
    array = group.create_dataset(key, data=smap.data, chunks=smap.data.shape)
    header = copy.deepcopy(smap.meta)
    # These keys can have dtypes that are not JSON serializable
    for unsafe in [k for k in ('history', 'comment') if k in header]:
        del header[unsafe]
    array.attrs['meta'] = header
    return key
def cutout(smap):
    """
    Crop ``smap`` to the rectangle spanned by ``CUTOUT_BLC`` and ``CUTOUT_TRC``.

    Both corners are interpreted in the map's own coordinate frame.
    """
    frame = smap.coordinate_frame
    bottom_left, top_right = (
        SkyCoord(*corner, frame=frame) for corner in (CUTOUT_BLC, CUTOUT_TRC)
    )
    return smap.submap(bottom_left, top_right=top_right)
# Collect the level 1 FITS files for each channel, sorted by filename
fits_format = 'aia.lev1_euv_12s.*.{channel}.image_lev1.fits'
channels = ['94', '131', '171', '193', '211', '335']
fits_files = {
    c: sorted(glob.glob(os.path.join(FITS_ROOT, fits_format.format(channel=c))))
    for c in channels
}
# Retrieve pointing and correction tables covering the full time range.
# The range is set from the first and last 94 files, padded by 3 h on each side.
t_start = sunpy.map.Map(fits_files['94'][0]).date - 3 * u.h
t_end = sunpy.map.Map(fits_files['94'][-1]).date + 3 * u.h
pointing_table = ac.util.get_pointing_table(t_start, t_end)
correction_table = ac.util.get_correction_table()
# Reference map: the 171 observation at REF_MAP_INDEX, with updated pointing
# and registered. It is only passed to shift_align as the alignment target.
m_ref = sunpy.map.Map(fits_files['171'][REF_MAP_INDEX])
m_ref = ac.register(ac.update_pointing(m_ref, pointing_table=pointing_table))
# Run pipeline
client = distributed.Client(address=SCHEDULER_ADDRESS)
print(client)
print(client.dashboard_link)
# Scatter the inputs shared by every task once, up front, rather than
# embedding them in each submitted task
m_ref_scatter = client.scatter(m_ref)
ptable_scatter = client.scatter(pointing_table)
ctable_scatter = client.scatter(correction_table)
for c in channels:
    futures = []
    for f in fits_files[c]:
        m_l1 = client.submit(sunpy.map.Map, f)
        m_pt = client.submit(ac.update_pointing, m_l1, pointing_table=ptable_scatter)
        m_reg = client.submit(ac.register, m_pt)
        m_deg = client.submit(ac.correct_degradation, m_reg, correction_table=ctable_scatter)
        m_norm = client.submit(ac.normalize_exposure, m_deg)
        m_align = client.submit(shift_align, m_norm, ref_map=m_ref_scatter)
        m_cutout = client.submit(cutout, m_align, pure=False)
        futures.append(client.submit(map_to_zarr, m_cutout, pure=False))
    print(f'Processing {c} files...')
    distributed.wait(futures)
    # wait() returns normally even when tasks errored, so report failures
    # explicitly instead of silently losing files; keep going with other channels
    failed = [fut for fut in futures if fut.status == 'error']
    if failed:
        print(f'{len(failed)} of {len(futures)} {c} files failed; '
              f'first error: {failed[0].exception()!r}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment