Created October 27, 2021 23:23
Save wtbarnes/f04fe7b4a4c38c05e87d99f5074032d1 to your computer and use it in GitHub Desktop.
NDCube subclasses for working with stacked and aligned AIA data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Utilities for working with aligned AIA data cubes | |
""" | |
import copy | |
from os import uname | |
import warnings | |
import astropy.time | |
import astropy.wcs | |
import astropy.units as u | |
from astropy.visualization import ImageNormalize, AsinhStretch, AsymmetricPercentileInterval | |
from astropy.wcs.wcs import WCS | |
import cupy | |
import distributed | |
import dask.array | |
import ndcube | |
import numpy as np | |
from scipy.interpolate import interp1d | |
import sunkit_image.time_lag | |
import sunkit_image.enhance | |
import sunpy.map | |
import zarr | |
# Public API of this module. (An entry of '' would make
# `from <module> import *` raise AttributeError.)
__all__ = [
    'map_from_zarr',
    'fits_to_zarr',
    'get_zarr_keys',
    'get_aia_collection_from_level_2_zarr',
    'AIACube',
    'AIACollection',
]
def map_from_zarr(filename, name):
    """
    Load a single sunpy map that was stored as one array in a Zarr store.

    Parameters
    ----------
    filename
        Zarr store (path or store object), opened read-only with `zarr.open`.
    name : `str`
        Path of the array inside the store. For stores written by
        `fits_to_zarr` this is ``'<wavelength>/<isot date>'``.

    Returns
    -------
    `sunpy.map.GenericMap`
        Map whose data is a lazy Dask view of the stored array and whose
        header is the array's ``meta`` attribute.
    """
    dataset = zarr.open(store=filename, mode='r')[name]
    return sunpy.map.Map(dask.array.from_zarr(dataset), dataset.attrs['meta'])
def fits_to_zarr(filename, root=None, chunks=None):
    """
    Copy a single map into a Zarr store as one array plus its header.

    The array is written at ``'<wavelength>/<isot date>'`` (wavelength
    formatted with no decimals). The map header is attached as the array's
    ``meta`` attribute, minus keys that may not be JSON serializable.

    Parameters
    ----------
    filename
        Anything `sunpy.map.Map` accepts that yields a single map.
    root : optional
        Zarr store passed to `zarr.open` (opened in append mode).
    chunks : `tuple`, optional
        Chunk shape used when the map data is a NumPy array. Defaults to
        the full image shape. Ignored when the map data is a Dask array.

    Returns
    -------
    `str`
        The path of the newly written array inside the store.
    """
    smap = sunpy.map.Map(filename)
    group = zarr.open(store=root, mode='a')
    name = f'{smap.wavelength.value:.0f}/{smap.date.isot}'
    if isinstance(smap.data, dask.array.Array):
        # dask.array.to_zarr expects a path, a zarr Array, or a raw store
        # mapping. A zarr Group is none of these, so hand it the group's
        # underlying store. The array then lands at `name` in the same
        # hierarchy.
        dask.array.to_zarr(smap.data, group.store, component=name)
        ds = group[name]
    else:
        chunks = smap.data.shape if chunks is None else chunks
        ds = group.create_dataset(name, data=smap.data, chunks=chunks)
    meta = copy.deepcopy(smap.meta)
    problem_keys = ['history', 'comment']  # these can have dtypes that are not JSON serializable
    for pk in problem_keys:
        if pk in meta:
            del meta[pk]
    ds.attrs['meta'] = meta
    return name
def get_zarr_keys(root, channel):
    """
    List the names of the arrays stored directly under one channel group.

    Parameters
    ----------
    root
        Zarr store (path or store object), opened read-only.
    channel : `str`
        Name of the group to inspect, e.g. ``'171'``.

    Returns
    -------
    `list` of `str`
        Array names in the order reported by zarr.
    """
    channel_group = zarr.open(root, mode='r')[channel]
    keys = channel_group.array_keys()
    return list(keys)
def get_aia_collection_from_level_2_zarr(zarr_store, chunks=None, exclude_masked=False):
    """
    Given a Zarr store containing a bunch of aligned maps, create
    a collection of cubes, one for each channel.

    The store must contain one group per channel, each holding one array
    per observation with a ``meta`` attribute. Stores written by
    `fits_to_zarr` have this layout. Each channel is stacked into an
    `AIACube`, and all cubes are interpolated onto a common time grid.

    Parameters
    ----------
    zarr_store
        Zarr store (path or store object), opened read-only.
    chunks : `tuple`, optional
        Spatial chunk shape forwarded to
        `AIACollection.from_uninterpolated_cubes`.
    exclude_masked : `bool`, optional
        Forwarded to `AIACollection.from_uninterpolated_cubes`.

    Returns
    -------
    `AIACollection`
    """
    store = zarr.open(zarr_store, mode='r')

    def _channel_cube(channel):
        # Read each observation lazily, in sorted key order, then stack them.
        maps = []
        for key in sorted(store[channel].array_keys()):
            dataset = store[f'{channel}/{key}']
            maps.append(sunpy.map.Map(dask.array.from_zarr(dataset), dataset.attrs['meta']))
        return AIACube.from_map_list(maps)

    labelled_cubes = [(channel, _channel_cube(channel)) for channel in list(store.group_keys())]
    return AIACollection.from_uninterpolated_cubes(
        labelled_cubes, chunks=chunks, exclude_masked=exclude_masked)
class AIACube(ndcube.NDCube):
    """
    NDCube holding a stack of co-aligned AIA images for a single channel.

    Most methods assume a 3D cube with array axes (time, lat, lon) and a
    FITS WCS whose third world axis is time, as built by `from_map_list`.
    ``meta`` is a dict mapping an integer frame index to that frame's header.
    """

    def _new_instance(self, data, **kwargs):
        """
        Given some alternate data representation, create another cube.

        WCS, metadata and unit are copied from this cube. The mask comes
        from ``kwargs['mask']`` when supplied, otherwise from this cube.
        Other keyword arguments are ignored.
        """
        mask = kwargs.get('mask', self.mask)
        return type(self)(data, self.wcs, meta=self.meta, unit=self.unit, mask=mask)

    @property
    def coordinate_frame(self):
        """Celestial coordinate frame derived from the cube's WCS."""
        return astropy.wcs.utils.wcs_to_celestial_frame(self.wcs)

    @property
    def observer_coordinate(self):
        """The ``observer`` attribute of `coordinate_frame`."""
        return self.coordinate_frame.observer

    @property
    def low_level_wcs(self):
        """
        The mutable ``Wcsprm`` (``.wcs``) behind the cube's FITS WCS.

        If the cube's low-level WCS is a ``SlicedLowLevelWCS``, this returns
        the ``Wcsprm`` of the WCS it wraps (its private ``_wcs``).
        NOTE(review): that is presumably the pre-slicing WCS; confirm
        callers expect that for sliced cubes.
        """
        # This is a bad hack that we probably shouldn't be doing
        # Annoyingly, once we slice, it is hard to get at a mutable
        # FITS WCS. There has to be a better way of doing this...
        if isinstance(self.wcs.low_level_wcs, astropy.wcs.wcsapi.SlicedLowLevelWCS):
            return self.wcs.low_level_wcs._wcs.wcs
        else:
            return self.wcs.low_level_wcs.wcs

    def _slice_to_map(self, index):
        """Return the frame at time index ``index`` as a `sunpy.map.Map`."""
        # NOTE: this only works on a 3D cube with (time, lat, lon)
        new_cube = self[index, :, :]
        return sunpy.map.Map(new_cube.data, astropy.wcs.WCS(new_cube.low_level_wcs.to_header()))

    @property
    @u.quantity_input
    def wavelength(self) -> u.angstrom:
        """Channel wavelength from the first frame's ``wavelnth``/``waveunit`` keys."""
        return u.Quantity(self.meta[0]['wavelnth'], self.meta[0]['waveunit'])

    @property
    def _wavelength_label(self):
        """Wavelength value formatted without decimals, e.g. ``'171'``."""
        return f'{self.wavelength.value:.0f}'

    @property
    def _dates(self):
        """
        Observation time of every frame, read from the per-frame metadata.

        For each frame (in sorted key order) the first key present among
        ``t_obs``, ``date_obs`` and ``date-obs`` is used.

        Raises
        ------
        KeyError
            If a frame has none of those keys.
        """
        # NOTE: this is wrong as soon as the cube is sliced
        warnings.warn('If your data and metadata axes are not aligned, this is wrong!')
        times = []
        for i in sorted(self.meta.keys()):
            header = self.meta[i]
            # Look up the *value* stored under each fallback key, in order of
            # preference.
            obs_time = header.get('t_obs', header.get('date_obs', header.get('date-obs')))
            if obs_time is None:
                raise KeyError(f'No t_obs, date_obs, or date-obs key in metadata for frame {i}')
            times.append(astropy.time.Time(obs_time))
        return astropy.time.Time(times)

    @property
    @u.quantity_input
    def _time_from_dates(self) -> u.s:
        """Frame times relative to the first frame, in seconds."""
        d = self._dates
        return (d - d[0]).to('s')

    @property
    @u.quantity_input
    def time(self) -> u.s:
        """World coordinate values along the first axis returned by NDCube."""
        # Just a convenient alias, where we assume that the first dimension
        # of our data is time
        return self.axis_world_coords_values()[0]

    @classmethod
    def from_zarr(cls, filename, channel):
        """
        Load a cube written by `to_zarr`.

        Reads the array ``channel`` together with its ``wcs`` header, its
        ``meta`` dict (JSON turns the integer keys into strings, so they are
        converted back to int here) and its ``unit``.
        """
        root = zarr.open(store=filename, mode='r')
        ds = root[channel]
        data = dask.array.from_zarr(ds)
        wcs = astropy.wcs.WCS(header=ds.attrs['wcs'])
        meta = ds.attrs['meta']
        meta = {int(k): v for k, v in meta.items()}
        return cls(data, wcs, meta=meta, unit=ds.attrs['unit'])

    @classmethod
    def from_map_list(cls, map_list):
        """
        Create a cube from a list of maps.

        Maps are sorted by their ``t_obs`` key and stacked along a new
        leading (time) axis. A linear time axis is appended to the first
        map's WCS. Its reference value is the first observation time in
        seconds since MJD 0, and its step is the gap between the first two
        observations, so the maps are assumed to be evenly spaced in time.
        If every map has a ``quality`` key, frames with a nonzero value are
        fully masked.

        Raises
        ------
        ValueError
            If fewer than two maps are given, since the cadence cannot be
            determined.
        """
        if len(map_list) < 2:
            raise ValueError('At least two maps are required to determine the time cadence.')
        # Ensure that the maps are sorted by the t_obs key
        map_list = sorted(map_list, key=lambda x: astropy.time.Time(x.meta['t_obs']))
        data = np.stack([m.data for m in map_list], axis=0)
        # This assumes that the observation time is stored in the t_obs key
        # This is non-standard and should probably be fixed.
        # The reason it doesn't use the .date attribute is because when the maps
        # are reprojected, the date is set to the date of the coordinate frame
        # which should all be the same since they have the same wcs.
        times = astropy.time.Time([map_list[0].meta['t_obs'], map_list[1].meta['t_obs']])
        wcs_header = map_list[0].wcs.to_header()
        # TODO: This should be a gwcs instead of a FITS WCS
        cards = [('CTYPE3', 'TIME'),
                 ('CUNIT3', 's'),
                 ('CRVAL3', (times.mjd[0] * u.d).to('s').value),
                 ('CRPIX3', 1),
                 ('CDELT3', (times[1] - times[0]).to('s').value)]
        cards += [(f'NAXIS{i+1}', s) for i, s in enumerate(data.shape[::-1])]
        for c in cards:
            wcs_header.append(c)
        # Create mask from the QUALITY keyword *only* if it exists for every header in the map
        quality = [m.meta.get('quality') for m in map_list]
        if None in quality:
            warnings.warn('QUALITY flag not present for all maps. Cannot construct mask.')
            quality_mask = None
        else:
            # Per-image mask: a nonzero QUALITY value masks that whole frame.
            quality_mask = np.zeros(data.shape, dtype=bool)
            quality_mask[np.where(np.array(quality) != 0)] = True
        return cls(
            data,
            astropy.wcs.WCS(wcs_header),
            mask=quality_mask,
            meta={_i: m.meta for _i, m in enumerate(map_list)},
            unit=map_list[0].unit,
        )

    @classmethod
    def from_zarr_map_list(cls, filename, channel):
        """
        Read from a zarr store in which each dataset is a sunpy map
        """
        keys = get_zarr_keys(filename, channel)
        map_list = [map_from_zarr(filename, f'{channel}/{k}') for k in keys]
        return cls.from_map_list(map_list)

    def to_zarr(self, filename, chunks=None):
        """
        Write the cube to a Zarr store under its wavelength label.

        The array is stored with attributes ``wcs`` (FITS header string),
        ``meta`` (per-frame headers minus ``history``/``comment``) and
        ``unit``. ``chunks`` (default: one chunk per time step) only applies
        to NumPy-backed data. Dask-backed data is written by
        `dask.array.to_zarr`, which does not receive ``chunks``.
        """
        root = zarr.open(store=filename, mode='a')
        if chunks is None:
            # Chunk such that each time step is a chunk
            chunks = (1,) + self.data.shape[1:]
        name = self._wavelength_label
        if isinstance(self.data, dask.array.Array):
            dask.array.to_zarr(self.data, filename, component=name)
            ds = root[name]
        else:
            ds = root.create_dataset(name, data=self.data, chunks=chunks)
        ds.attrs['wcs'] = self.low_level_wcs.to_header()
        meta = copy.deepcopy(self.meta)
        problem_keys = ['history', 'comment']  # these can have dtypes that are not JSON serializable
        for i in meta:
            for pk in problem_keys:
                if pk in meta[i]:
                    del meta[i][pk]
        ds.attrs['meta'] = meta
        ds.attrs['unit'] = self.unit.to_string()

    def rechunk(self, chunks):
        """Return a new cube with the Dask data rechunked to ``chunks``."""
        if not isinstance(self.data, dask.array.Array):
            raise ValueError('Can only rechunk a Dask array')
        data = self.data.rechunk(chunks)
        return self._new_instance(data)

    @u.quantity_input
    def interpolate_time(self, time: u.s, exclude_masked=False):
        """
        Linearly interpolate (and extrapolate) the cube onto a new time grid.

        Parameters
        ----------
        time : `~astropy.units.Quantity`
            Evenly spaced target times, measured from the first frame (the
            same origin as `_time_from_dates`).
        exclude_masked : `bool`, optional
            If True, frames containing any masked pixel are left out of the
            interpolation samples. A cube without a mask has nothing to
            exclude, so all frames are used.

        Returns
        -------
        `AIACube`
            New cube without a mask. Its WCS time axis starts at
            ``time[0]`` with step ``time[1] - time[0]``.

        Raises
        ------
        ValueError
            If ``time`` is not evenly spaced.

        Notes
        -----
        The data must be Dask-backed (``self.data.chunks`` is used), and each
        block must contain every time step, because interpolation runs
        independently per block along axis 0.
        """
        # NOTE: This function relies on your metadata dict being aligned with your
        # data. Note that you need to slice this *manually* as NDCube will not do
        # it for you. Use this function cautiously!!!
        if not u.quantity.allclose(np.diff(time), np.diff(time)[0], atol=None, rtol=1e-6):
            raise ValueError('Interpolation time must be evenly spaced.')
        t = self._time_from_dates.to(time.unit).value
        t_interp = time.value
        if exclude_masked and self.mask is not None:
            # Keep only frames in which no pixel is masked.
            unmasked_indices = np.where(~np.any(self.mask, axis=(1, 2)))
        else:
            unmasked_indices = np.arange(t.shape[0])
        data_interp = dask.array.map_blocks(
            lambda y: interp1d(t[unmasked_indices], y[unmasked_indices],
                               axis=0,
                               kind='linear',
                               fill_value='extrapolate')(t_interp),
            self.data,
            chunks=t_interp.shape + self.data.chunks[1:],
            dtype=self.data.dtype
        )
        wcs = copy.deepcopy(self.low_level_wcs)
        wcs.cdelt[2] = np.diff(t_interp)[0]
        wcs.crval[2] = t_interp[0]
        wcs.crpix[2] = 1
        wcs.cunit[2] = time.unit.to_string()
        wcs = WCS(header=wcs.to_header())
        # NOTE: I'm explicitly not passing the mask on to the new instance
        # because the mask does not align with the new interpolated data.
        return type(self)(data_interp, wcs, meta=self.meta, unit=self.unit)

    def gather(self):
        """Return a cube where the data is actually in memory"""
        return self._new_instance(self.data.compute())

    def to_gpu(self):
        """Move the cube data to the GPU with cupy"""
        return self._new_instance(cupy.array(self.data))

    def to_dask(self, chunks='auto'):
        """Wrap the data in a Dask array with the given ``chunks``."""
        return self._new_instance(dask.array.from_array(self.data, chunks=chunks))

    def quick_movie(self, filename=None, dpi=200):
        """
        Animate the cube with an asinh stretch and the channel's AIA colormap.

        Returns the animation object when ``filename`` is None. Otherwise
        the animation is saved to ``filename`` with the ffmpeg writer.
        """
        vmin, vmax = AsymmetricPercentileInterval(1, 99.5).get_limits(self.data)
        vmin = max(0, vmin)
        norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=AsinhStretch(a=0.01))
        ani = self.plot(cmap=f'sdoaia{self.wavelength.value:.0f}', norm=norm)
        if filename is None:
            return ani
        else:
            ani.get_animation().save(filename, writer='ffmpeg', dpi=dpi)

    def apply_mgn(self, **kwargs):
        """
        Apply MGN filtering to each slice of the cube.

        This applies this operation in parallel with Dask.
        You must spin up a client before calling this method.

        ``gamma_min``/``gamma_max`` default to the global min/max of the
        cube. Any ``kwargs`` override them and are passed on to
        `sunkit_image.enhance.mgn`.
        """
        client = distributed.get_client()
        cube = self.gather()
        gmin = cube.data.min()
        gmax = cube.data.max()
        mgn_kwargs = {'gamma_min': gmin, 'gamma_max': gmax}
        mgn_kwargs.update(kwargs)
        # Iterating over the data yields one 2D slice per time step.
        futures = client.map(sunkit_image.enhance.mgn, cube.data, **mgn_kwargs)
        data_mgn = np.array(client.gather(futures))
        return self._new_instance(data_mgn)
class AIACollection(ndcube.NDCollection):
    """
    Collection of `AIACube` objects keyed by channel label (e.g. ``'171'``).

    Every cube is aligned with the others along all three array axes.
    """

    def to_zarr(self, filename, **kwargs):
        """Write every cube into the same Zarr store with `AIACube.to_zarr`."""
        for k in self:
            self[k].to_zarr(filename, **kwargs)

    @classmethod
    def from_zarr(cls, filename, channels=None):
        """
        Load a collection written by `to_zarr`.

        Parameters
        ----------
        filename
            Zarr store passed to `AIACube.from_zarr`.
        channels : `list` of `str`, optional
            Channel labels to load. Defaults to
            ``['94', '131', '171', '193', '211', '335']``.
        """
        if channels is None:
            channels = ['94', '131', '171', '193', '211', '335']
        cubes = [(c, AIACube.from_zarr(filename, c)) for c in channels]
        return cls(cubes, aligned_axes=(0, 1, 2))

    @classmethod
    def from_uninterpolated_cubes(cls, cubes, chunks=None, exclude_masked=False):
        """
        Given a list of cubes, return a collection of these
        cubes where each has been interpolated to a uniform
        time array.

        Parameters
        ----------
        cubes : `list`
            List of tuples, where the first entry is a string representation
            of the channel and the second entry is the (uninterpolated) cube.
        chunks : `tuple` or `list`, optional
            Spatial chunk shape (defaults to ``(300, 300)``). The time axis is
            always a single chunk, as `AIACube.interpolate_time` requires.
        exclude_masked : `bool`, optional
            Forwarded to `AIACube.interpolate_time`.
        """
        if chunks is None:
            chunks = (300, 300)
        # Target grid: as many points as the cube with the most frames,
        # evenly spanning zero to the largest final time offset of any cube.
        all_times = [cube._time_from_dates for _, cube in cubes]
        n_max = max(t.shape[0] for t in all_times)
        t_max = max(t[-1] for t in all_times)
        t_interp = np.linspace(0, t_max.value, n_max) * t_max.unit
        # Interpolate each cube
        interp_cubes = []
        for label, cube in cubes:
            # tuple() so list-valued chunks can be concatenated with the shape tuple.
            chunks_total = cube.data.shape[:1] + tuple(chunks)
            interp_cubes.append((label, cube.rechunk(chunks_total).interpolate_time(
                t_interp, exclude_masked=exclude_masked)))
        return cls(interp_cubes, aligned_axes=(0, 1, 2))

    def gather(self):
        """Return a collection whose cube data are all computed into memory."""
        return type(self)([(k, self[k].gather()) for k in self], aligned_axes=(0, 1, 2))

    def to_gpu(self):
        """Return a collection whose cube data are all moved to the GPU."""
        return type(self)([(k, self[k].to_gpu()) for k in self], aligned_axes=(0, 1, 2))

    def to_dask(self, chunks='auto'):
        """Return a collection whose cube data are all wrapped in Dask arrays."""
        return type(self)([(k, self[k].to_dask(chunks=chunks)) for k in self], aligned_axes=(0, 1, 2))

    @property
    @u.quantity_input
    def time(self) -> u.s:
        """
        The shared time axis of all cubes.

        Raises
        ------
        ValueError
            If the cubes do not all have identical time arrays.
        """
        all_times = [self[k].time for k in self]
        if not all(np.all(t == all_times[0]) for t in all_times):
            raise ValueError('All times are not equal. Cannot return a single time.')
        return all_times[0]

    @property
    @u.quantity_input
    def wavelengths(self) -> u.angstrom:
        """Wavelength of every cube, in iteration order."""
        return u.Quantity([self[k].wavelength for k in self])

    def rechunk(self, chunks):
        """Return a collection with every cube rechunked to ``chunks``."""
        return type(self)([(k, self[k].rechunk(chunks)) for k in self], aligned_axes=(0, 1, 2))

    def peak_cross_correlation_map(self, channel_a, channel_b, **kwargs):
        """
        Construct map of peak cross-correlation between two channels in each pixel
        of an AIA map.

        Parameters
        ----------
        channel_a, channel_b : `str`
            Keys of the two cubes to correlate.
        plot_settings : `dict`, optional
            Overrides for the returned map's plot settings. This is consumed
            here and is not forwarded to sunkit-image.
        kwargs
            Passed to `sunkit_image.time_lag.max_cross_correlation`.
        """
        # Remove the map-display option so it does not reach the computation.
        plot_overrides = kwargs.pop('plot_settings', {})
        max_cc = sunkit_image.time_lag.max_cross_correlation(
            self[channel_a].data,
            self[channel_b].data,
            self.time,
            **kwargs,
        )
        max_cc_map = sunpy.map.Map(max_cc, self[channel_a].wcs.celestial)
        max_cc_map.meta['bunit'] = ''
        max_cc_map.meta['comment'] = f'{channel_a}-{channel_b} cross-correlation'
        plot_settings = {
            'cmap': 'plasma',
            'vmin': 0,
            'vmax': 1,
        }
        plot_settings.update(plot_overrides)
        max_cc_map.plot_settings.update(plot_settings)
        return max_cc_map

    def time_lag_map(self, channel_a, channel_b, **kwargs):
        """
        Construct map of timelag values that maximize the cross-correlation between
        two channels in each pixel of an AIA map.

        Parameters
        ----------
        channel_a, channel_b : `str`
            Keys of the two cubes to correlate.
        plot_settings : `dict`, optional
            Overrides for the returned map's plot settings. This is consumed
            here and is not forwarded to sunkit-image.
        kwargs
            Passed to `sunkit_image.time_lag.time_lag`.
        """
        # Remove the map-display option so it does not reach the computation.
        plot_overrides = kwargs.pop('plot_settings', {})
        max_timelag = sunkit_image.time_lag.time_lag(
            self[channel_a].data,
            self[channel_b].data,
            self.time,
            **kwargs,
        )
        time_lag_map = sunpy.map.Map(max_timelag, self[channel_a].wcs.celestial)
        time_lag_map.meta['bunit'] = 's'
        time_lag_map.meta['comment'] = f'{channel_a}-{channel_b} time lag'
        plot_settings = {
            'cmap': 'RdBu_r',
        }
        plot_settings.update(plot_overrides)
        time_lag_map.plot_settings.update(plot_settings)
        return time_lag_map
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment