Created
June 10, 2021 13:30
-
-
Save wtbarnes/8c1e8e8e39414784fa24cca3e697dfff to your computer and use it in GitHub Desktop.
Pipeline for creating aligned, level 1.5 cutouts from level 1 AIA images
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Process images from level 1 FITS to level 1.5 aligned cutouts in Zarr | |
""" | |
import copy | |
import glob | |
import os | |
import aiapy.calibrate as ac | |
import astropy.units as u | |
from astropy.coordinates import SkyCoord | |
import distributed | |
import numpy as np | |
from scipy.ndimage import shift | |
import sunpy.map | |
from sunpy.physics.differential_rotation import solar_rotate_coordinate | |
from sunpy.map.header_helper import get_observer_meta | |
import zarr | |
# Address of the running Dask scheduler this script connects to; edit to
# match your own scheduler
SCHEDULER_ADDRESS = 'tcp://127.0.0.1:34117'
# Zarr store that receives the aligned level 1.5 cutouts; edit to choose the
# output location
ZARR_L15_ROOT = '/path/to/L15_cutout_aligned.zarr'
# Directory that holds the input level 1 FITS files; edit to point at your data
FITS_ROOT = '/path/to/L1_fits/'
# Index (in time) into the sorted 171 file list of the map to align to
REF_MAP_INDEX = 210
# Bottom-left and top-right corners of the cutout, given in the HPC frame of
# the reference map selected above
CUTOUT_BLC = (-425, -200) * u.arcsec
CUTOUT_TRC = (100, 250) * u.arcsec
# Define a few functions we will need for prepping | |
def shift_align(smap, ref_map=None, rot_type='snodgrass'):
    """
    Shift the image in ``smap`` so that it lines up with ``ref_map``.

    The center of ``smap`` is rotated to the observer of ``ref_map`` with
    `~sunpy.physics.differential_rotation.solar_rotate_coordinate`, and the
    image array is translated by the pixel offset between that rotated center
    and the center of ``ref_map``.

    Parameters
    ----------
    smap
        Map to be shifted.
    ref_map
        Map to align to. It is a keyword argument only so that the function
        plays nicely with ``client.map``; it is required in practice.
    rot_type
        Rotation model name, passed through to ``solar_rotate_coordinate``.

    Returns
    -------
    A new map created with ``smap._new_instance`` holding the shifted data.
    In its metadata ``date-obs`` is the observation time of the reference
    observer, ``t_obs`` holds the original ``smap.date``, and the observer
    keywords are replaced by those of the reference map.

    Raises
    ------
    ValueError
        If ``ref_map`` is not provided.
    """
    # NOTE: Alternatively, we could use the
    # `~sunpy.physics.differential_rotate` but it is considerably slower.
    if ref_map is None:
        raise ValueError('Must provide a reference map.')
    rotated_center = solar_rotate_coordinate(
        smap.center, observer=ref_map.observer_coordinate, rot_type=rot_type)
    # Calculate the shift in pixels. Convert explicitly to pixels rather than
    # reading `.value`, so the result stays correct even when the coordinate
    # and the map scale are expressed in different angular units.
    x_shift = ((rotated_center.Tx - ref_map.center.Tx) / smap.scale.axis1).to_value(u.pix)
    y_shift = ((rotated_center.Ty - ref_map.center.Ty) / smap.scale.axis2).to_value(u.pix)
    # TODO: implement in Dask
    # Array axes are ordered (row, column), hence (y, x)
    data_shifted = shift(smap.data, [y_shift, x_shift])
    # Update metadata
    new_meta = copy.deepcopy(smap.meta)
    # Drop the underscore variant so only the 'date-obs' key set below remains
    new_meta.pop('date_obs', None)
    new_meta['date-obs'] = ref_map.observer_coordinate.obstime.isot
    # Preserve the old time
    new_meta['t_obs'] = smap.date.isot
    new_meta.update(get_observer_meta(ref_map.observer_coordinate,
                                      new_meta['rsun_ref'] * u.m))
    return smap._new_instance(data_shifted, new_meta)
def map_to_zarr(smap):
    """
    Write the data and metadata of ``smap`` into the Zarr store at
    ``ZARR_L15_ROOT``.

    The dataset is named ``<wavelength>/<t_obs>``, where ``t_obs`` is read
    from the map metadata, and is stored as a single chunk covering the whole
    array. The metadata, minus the ``history`` and ``comment`` keys, is
    attached to the dataset under the ``meta`` attribute.

    Returns the name of the dataset that was created.
    """
    store_root = zarr.open(store=ZARR_L15_ROOT, mode='a')
    name = f'{smap.wavelength.value:.0f}/{smap.meta["t_obs"]}'
    dataset = store_root.create_dataset(name, data=smap.data, chunks=smap.data.shape)
    header = copy.deepcopy(smap.meta)
    # These entries can have dtypes that are not JSON serializable
    for key in ('history', 'comment'):
        if key in header:
            del header[key]
    dataset.attrs['meta'] = header
    return name
def cutout(smap):
    """
    Return the submap of ``smap`` bounded by ``CUTOUT_BLC`` (bottom left) and
    ``CUTOUT_TRC`` (top right), with both corners interpreted in the
    coordinate frame of ``smap``.
    """
    frame = smap.coordinate_frame
    blc_x, blc_y = CUTOUT_BLC
    trc_x, trc_y = CUTOUT_TRC
    bottom_left = SkyCoord(blc_x, blc_y, frame=frame)
    top_right = SkyCoord(trc_x, trc_y, frame=frame)
    return smap.submap(bottom_left, top_right=top_right)
# Collect the level 1 FITS files of every channel, keyed by channel name and
# sorted by filename within each channel
fits_format = 'aia.lev1_euv_12s.*.{channel}.image_lev1.fits'
channels = ['94', '131', '171', '193', '211', '335']
fits_files = {
    channel: sorted(glob.glob(os.path.join(FITS_ROOT, fits_format.format(channel=channel))))
    for channel in channels
}
# Retrieve pointing and correction tables for the full time range, padded by
# 3 hours on either side of the first and last 94 observations
if not fits_files['94']:
    # Fail with a clear message instead of a bare IndexError below
    raise FileNotFoundError(
        f'No 94 level 1 FITS files found in {FITS_ROOT!r}; '
        'cannot determine the time range for the pointing table.'
    )
t_start = sunpy.map.Map(fits_files['94'][0]).date - 3 * u.h
t_end = sunpy.map.Map(fits_files['94'][-1]).date + 3 * u.h
pointing_table = ac.util.get_pointing_table(t_start, t_end)
correction_table = ac.util.get_correction_table()
# Define the reference map as the 171 observation at REF_MAP_INDEX
try:
    ref_file = fits_files['171'][REF_MAP_INDEX]
except IndexError as err:
    raise IndexError(
        f'REF_MAP_INDEX={REF_MAP_INDEX} is out of range: only '
        f'{len(fits_files["171"])} 171 level 1 FITS files were found.'
    ) from err
m_ref = sunpy.map.Map(ref_file)
# Give the reference map the same pointing update and registration that the
# pipeline applies to every image
m_ref = ac.register(ac.update_pointing(m_ref, pointing_table=pointing_table))
# Connect to the scheduler and report which cluster we are talking to
client = distributed.Client(address=SCHEDULER_ADDRESS)
print(client)
print(client.dashboard_link)
# Scatter the inputs shared by every task up front to be nicer to the
# scheduler
m_ref_scatter = client.scatter(m_ref)
ptable_scatter = client.scatter(pointing_table)
ctable_scatter = client.scatter(correction_table)
# Process one channel at a time: submit the full task chain for each file,
# then block until all files of that channel are done
for c in channels:
    futures = []
    for fits_path in fits_files[c]:
        level1 = client.submit(sunpy.map.Map, fits_path)
        pointed = client.submit(ac.update_pointing, level1, pointing_table=ptable_scatter)
        registered = client.submit(ac.register, pointed)
        corrected = client.submit(
            ac.correct_degradation, registered, correction_table=ctable_scatter)
        normalized = client.submit(ac.normalize_exposure, corrected)
        aligned = client.submit(shift_align, normalized, ref_map=m_ref_scatter)
        # NOTE(review): pure=False is presumably used because these two steps
        # depend on module-level settings / write to disk -- confirm
        cropped = client.submit(cutout, aligned, pure=False)
        written = client.submit(map_to_zarr, cropped, pure=False)
        futures.append(written)
    print(f'Processing {c} files...')
    distributed.wait(futures)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment