"""
Utilities and NDCube subclasses for working with stacked and aligned AIA
data cubes.
"""
import copy
import warnings

import astropy.time
import astropy.units as u
import astropy.wcs
from astropy.visualization import AsinhStretch, AsymmetricPercentileInterval, ImageNormalize
from astropy.wcs import WCS
import cupy
import distributed
import dask.array
import ndcube
import numpy as np
from scipy.interpolate import interp1d
import sunkit_image.time_lag
import sunkit_image.enhance
import sunpy.map
import zarr

__all__ = ['map_from_zarr', 'fits_to_zarr', 'get_zarr_keys',
           'get_aia_collection_from_level_2_zarr', 'AIACube', 'AIACollection']


def map_from_zarr(filename, name):
    """Load a single `~sunpy.map.Map` from a dataset in a Zarr store."""
    root = zarr.open(store=filename, mode='r')
    ds = root[name]
    # Read lazily; the pixel data are only pulled in when computed
    data = dask.array.from_zarr(ds)
    meta = ds.attrs['meta']
    return sunpy.map.Map(data, meta)


def fits_to_zarr(filename, root=None, chunks=None):
    """Write a FITS file to a Zarr store, keyed by wavelength and time."""
    smap = sunpy.map.Map(filename)
    root = zarr.open(store=root, mode='a')
    name = f'{smap.wavelength.value:.0f}/{smap.date.isot}'
    if isinstance(smap.data, dask.array.Array):
        dask.array.to_zarr(smap.data, root, component=name)
        ds = root[name]
    else:
        chunks = smap.data.shape if chunks is None else chunks
        ds = root.create_dataset(name, data=smap.data, chunks=chunks)
    meta = copy.deepcopy(smap.meta)
    problem_keys = ['history', 'comment']  # these can have dtypes that are not JSON serializable
    for pk in problem_keys:
        if pk in meta:
            del meta[pk]
    ds.attrs['meta'] = meta
    return name


def get_zarr_keys(root, channel):
    """Return the names of all arrays under ``channel`` in a Zarr store."""
    root = zarr.open(root, mode='r')
    return list(root[channel].array_keys())
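

# Example round trip (a sketch; 'aia_171.fits' and the 'aia_maps.zarr' store
# are hypothetical names; the group/key layout mirrors the
# '{wavelength}/{date}' naming used by fits_to_zarr above):
#
# >>> name = fits_to_zarr('aia_171.fits', root='aia_maps.zarr')
# >>> keys = get_zarr_keys('aia_maps.zarr', '171')
# >>> smap = map_from_zarr('aia_maps.zarr', f'171/{keys[0]}')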


def get_aia_collection_from_level_2_zarr(zarr_store, chunks=None, exclude_masked=False):
    """
    Given a Zarr store containing a bunch of aligned maps, create
    a collection of cubes, one for each channel.

    Parameters
    ----------
    zarr_store : path-like or `zarr.storage.Store`
        Store holding one group per channel, with one array per time step.
    chunks : `tuple`, optional
        Spatial chunk shape used when rechunking each cube prior to
        interpolating in time.
    exclude_masked : `bool`, optional
        If True, exclude masked time steps when interpolating.
    """
    root = zarr.open(zarr_store, mode='r')
    # Get the list of channels
    channels = list(root.group_keys())
    # For every channel, collect all maps and turn each list of maps into a cube
    cube_list = []
    for c in channels:
        times = sorted(root[c].array_keys())
        map_list = []
        for t in times:
            ds = root[f'{c}/{t}']
            map_list.append(sunpy.map.Map(dask.array.from_zarr(ds), ds.attrs['meta']))
        cube_list.append((c, AIACube.from_map_list(map_list)))
    return AIACollection.from_uninterpolated_cubes(cube_list, chunks=chunks, exclude_masked=exclude_masked)
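

# Example (a sketch; 'level_2.zarr' is a hypothetical store populated by
# repeated calls to fits_to_zarr, one group per channel):
#
# >>> collection = get_aia_collection_from_level_2_zarr('level_2.zarr',
# ...                                                   chunks=(300, 300))
# >>> collection.wavelengths  # one entry per channel group in the store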


class AIACube(ndcube.NDCube):

    def _new_instance(self, data, **kwargs):
        """
        Given some alternate data representation, create another cube.
        """
        mask = kwargs.get('mask', self.mask)
        return type(self)(data, self.wcs, meta=self.meta, unit=self.unit, mask=mask)

    @property
    def coordinate_frame(self):
        return astropy.wcs.utils.wcs_to_celestial_frame(self.wcs)

    @property
    def observer_coordinate(self):
        return self.coordinate_frame.observer

    @property
    def low_level_wcs(self):
        # This is a bad hack that we probably shouldn't be doing.
        # Annoyingly, once we slice, it is hard to get at a mutable
        # FITS WCS. There has to be a better way of doing this...
        if isinstance(self.wcs.low_level_wcs, astropy.wcs.wcsapi.SlicedLowLevelWCS):
            return self.wcs.low_level_wcs._wcs.wcs
        else:
            return self.wcs.low_level_wcs.wcs

    def _slice_to_map(self, index):
        # NOTE: this only works on a 3D cube with (time, lat, lon)
        new_cube = self[index, :, :]
        return sunpy.map.Map(new_cube.data, astropy.wcs.WCS(new_cube.low_level_wcs.to_header()))

    @property
    @u.quantity_input
    def wavelength(self) -> u.angstrom:
        return u.Quantity(self.meta[0]['wavelnth'], self.meta[0]['waveunit'])

    @property
    def _wavelength_label(self):
        return f'{self.wavelength.value:.0f}'

    @property
    def _dates(self):
        # NOTE: this is wrong as soon as the cube is sliced
        warnings.warn('If your data and metadata axes are not aligned, this is wrong!')
        indices = sorted(self.meta.keys())
        times = []
        for i in indices:
            # Prefer the t_obs key, falling back to date_obs and then date-obs
            default = self.meta[i].get('date_obs', self.meta[i].get('date-obs'))
            times.append(astropy.time.Time(self.meta[i].get('t_obs', default)))
        return astropy.time.Time(times)

    @property
    @u.quantity_input
    def _time_from_dates(self) -> u.s:
        d = self._dates
        return (d - d[0]).to('s')

    @property
    @u.quantity_input
    def time(self) -> u.s:
        # Just a convenient alias, where we assume that the first dimension
        # of our data is time
        return self.axis_world_coords_values()[0]

    @classmethod
    def from_zarr(cls, filename, channel):
        root = zarr.open(store=filename, mode='r')
        ds = root[channel]
        data = dask.array.from_zarr(ds)
        wcs = astropy.wcs.WCS(header=ds.attrs['wcs'])
        meta = ds.attrs['meta']
        # JSON serialization turns the integer keys into strings; undo that
        meta = {int(k): v for k, v in meta.items()}
        return cls(data, wcs, meta=meta, unit=ds.attrs['unit'])

    @classmethod
    def from_map_list(cls, map_list):
        """
        Create a cube from a list of maps.
        """
        # Ensure that the maps are sorted by the t_obs key
        map_list = sorted(map_list, key=lambda x: astropy.time.Time(x.meta['t_obs']))
        data = np.stack([m.data for m in map_list], axis=0)
        # This assumes that the observation time is stored in the t_obs key.
        # This is non-standard and should probably be fixed.
        # The reason it doesn't use the .date attribute is that when the maps
        # are reprojected, the date is set to the date of the coordinate frame,
        # which should all be the same since they have the same WCS.
        times = astropy.time.Time([map_list[0].meta['t_obs'], map_list[1].meta['t_obs']])
        wcs_header = map_list[0].wcs.to_header()
        # TODO: This should be a gwcs instead of a FITS WCS
        cards = [('CTYPE3', 'TIME'),
                 ('CUNIT3', 's'),
                 ('CRVAL3', (times.mjd[0] * u.d).to('s').value),
                 ('CRPIX3', 1),
                 ('CDELT3', (times[1] - times[0]).to('s').value)]
        cards += [(f'NAXIS{i+1}', s) for i, s in enumerate(data.shape[::-1])]
        for c in cards:
            wcs_header.append(c)
        # Create a mask from the QUALITY keyword *only* if it exists for every header
        quality = [m.meta.get('quality') for m in map_list]
        if None in quality:
            warnings.warn('QUALITY flag not present for all maps. Cannot construct mask.')
            quality_mask = None
        else:
            # This is a per-image mask: every pixel in an image with a
            # nonzero QUALITY flag is masked.
            quality_mask = np.zeros(data.shape, dtype=bool)
            quality_mask[np.where(np.array(quality) != 0)] = True
        return cls(
            data,
            astropy.wcs.WCS(wcs_header),
            mask=quality_mask,
            meta={_i: m.meta for _i, m in enumerate(map_list)},
            unit=map_list[0].unit,
        )

    @classmethod
    def from_zarr_map_list(cls, filename, channel):
        """
        Read from a Zarr store in which each dataset is a sunpy map.
        """
        keys = get_zarr_keys(filename, channel)
        map_list = [map_from_zarr(filename, f'{channel}/{k}') for k in keys]
        return cls.from_map_list(map_list)
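
    # Example (a sketch; 'aia_maps.zarr' is a hypothetical store with one
    # array per time step under the '171' group, as written by fits_to_zarr):
    #
    # >>> cube = AIACube.from_zarr_map_list('aia_maps.zarr', '171')
    # >>> cube.wavelength   # 171 Angstrom
    # >>> cube._slice_to_map(0).peek()  # first time step as a sunpy map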

    def to_zarr(self, filename, chunks=None):
        root = zarr.open(store=filename, mode='a')
        if chunks is None:
            # Chunk such that each time step is a chunk
            chunks = (1,) + self.data.shape[1:]
        name = self._wavelength_label
        if isinstance(self.data, dask.array.Array):
            dask.array.to_zarr(self.data, filename, component=name)
            ds = root[name]
        else:
            ds = root.create_dataset(name, data=self.data, chunks=chunks)
        ds.attrs['wcs'] = self.low_level_wcs.to_header()
        meta = copy.deepcopy(self.meta)
        problem_keys = ['history', 'comment']  # these can have dtypes that are not JSON serializable
        for i in meta:
            for pk in problem_keys:
                if pk in meta[i]:
                    del meta[i][pk]
        ds.attrs['meta'] = meta
        ds.attrs['unit'] = self.unit.to_string()

    def rechunk(self, chunks):
        if not isinstance(self.data, dask.array.Array):
            raise ValueError('Can only rechunk a Dask array')
        data = self.data.rechunk(chunks)
        return self._new_instance(data)

    @u.quantity_input
    def interpolate_time(self, time: u.s, exclude_masked=False):
        # NOTE: This function relies on your metadata dict being aligned with your
        # data. Note that you need to slice this *manually* as NDCube will not do
        # it for you. Use this function cautiously!!!
        if not u.allclose(np.diff(time), np.diff(time)[0], atol=None, rtol=1e-6):
            raise ValueError('Interpolation time must be evenly spaced.')
        t = self._time_from_dates.to(time.unit).value
        t_interp = time.value
        if exclude_masked:
            # Drop any time step in which at least one pixel is masked
            unmasked_indices = np.where(~np.any(self.mask, axis=(1, 2)))
        else:
            unmasked_indices = np.arange(t.shape[0])
        data_interp = dask.array.map_blocks(
            lambda y: interp1d(t[unmasked_indices], y[unmasked_indices],
                               axis=0,
                               kind='linear',
                               fill_value='extrapolate')(t_interp),
            self.data,
            chunks=t_interp.shape + self.data.chunks[1:],
            dtype=self.data.dtype,
        )
        # Build a new WCS with the time axis replaced by the interpolated grid
        wcs = copy.deepcopy(self.low_level_wcs)
        wcs.cdelt[2] = np.diff(t_interp)[0]
        wcs.crval[2] = t_interp[0]
        wcs.crpix[2] = 1
        wcs.cunit[2] = time.unit.to_string()
        wcs = WCS(header=wcs.to_header())
        # NOTE: I'm explicitly not passing the mask on to the new instance
        # because the mask does not align with the new interpolated data.
        return type(self)(data_interp, wcs, meta=self.meta, unit=self.unit)
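
    # Example (a sketch; assumes `cube` is Dask-backed with a mask, and that
    # an hour of data is to be resampled onto a uniform 12 s cadence):
    #
    # >>> t_interp = np.arange(0, 3600, 12) * u.s
    # >>> cube_interp = cube.interpolate_time(t_interp, exclude_masked=True)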

    def gather(self):
        """Return a cube where the data is actually in memory."""
        return self._new_instance(self.data.compute())

    def to_gpu(self):
        """Move the cube data to the GPU with cupy."""
        return self._new_instance(cupy.array(self.data))

    def to_dask(self, chunks='auto'):
        """Return a cube backed by a lazy Dask array."""
        return self._new_instance(dask.array.from_array(self.data, chunks=chunks))

    def quick_movie(self, filename=None, dpi=200):
        vmin, vmax = AsymmetricPercentileInterval(1, 99.5).get_limits(self.data)
        vmin = max(0, vmin)
        norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=AsinhStretch(a=0.01))
        ani = self.plot(cmap=f'sdoaia{self.wavelength.value:.0f}', norm=norm)
        if filename is None:
            return ani
        else:
            ani.get_animation().save(filename, writer='ffmpeg', dpi=dpi)

    def apply_mgn(self, **kwargs):
        """
        Apply MGN filtering to each slice of the cube.

        This applies the operation in parallel with Dask.
        You must spin up a client before calling this method.
        """
        client = distributed.get_client()
        cube = self.gather()
        gmin = cube.data.min()
        gmax = cube.data.max()
        mgn_kwargs = {'gamma_min': gmin, 'gamma_max': gmax}
        mgn_kwargs.update(kwargs)
        # Map the MGN filter over each time slice and gather the results
        futures = client.map(sunkit_image.enhance.mgn, cube.data, **mgn_kwargs)
        data_mgn = np.array(client.gather(futures))
        return self._new_instance(data_mgn)
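

# Example (a sketch; assumes `cube` is an AIACube backed by Dask and that
# ffmpeg is available to write the movie):
#
# >>> cube.rechunk((1, 512, 512)).quick_movie('aia_171.mp4', dpi=200)
# >>> cube_mem = cube.gather()  # pull the Dask-backed data into memory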


class AIACollection(ndcube.NDCollection):

    def to_zarr(self, filename, **kwargs):
        for k in self:
            self[k].to_zarr(filename, **kwargs)

    @classmethod
    def from_zarr(cls, filename, channels=None):
        if channels is None:
            channels = ['94', '131', '171', '193', '211', '335']
        cubes = [(c, AIACube.from_zarr(filename, c)) for c in channels]
        return cls(cubes, aligned_axes=(0, 1, 2))

    @classmethod
    def from_uninterpolated_cubes(cls, cubes, chunks=None, exclude_masked=False):
        """
        Given a list of cubes, return a collection of these
        cubes where each has been interpolated to a uniform
        time array.

        Parameters
        ----------
        cubes : `list`
            List of tuples, where the first entry is a string representation
            of the channel and the second entry is the (uninterpolated) cube.
        chunks : `tuple`, optional
            Spatial chunk shape used when rechunking each cube.
            Defaults to (300, 300).
        exclude_masked : `bool`, optional
            If True, exclude masked time steps when interpolating.
        """
        if chunks is None:
            chunks = (300, 300)
        # Find the time array to interpolate to: the longest time axis,
        # spanning from zero to the latest end time of any cube
        all_times = [cube._time_from_dates for _, cube in cubes]
        n_max = max([t.shape[0] for t in all_times])
        t_max = max([t[-1] for t in all_times])
        t_interp = np.linspace(0, t_max.value, n_max) * t_max.unit
        # Interpolate each cube
        interp_cubes = []
        for label, cube in cubes:
            chunks_total = cube.data.shape[:1] + chunks
            interp_cubes.append((label, cube.rechunk(chunks_total).interpolate_time(
                t_interp, exclude_masked=exclude_masked)))
        return cls(interp_cubes, aligned_axes=(0, 1, 2))

    def gather(self):
        return type(self)([(k, self[k].gather()) for k in self], aligned_axes=(0, 1, 2))

    def to_gpu(self):
        return type(self)([(k, self[k].to_gpu()) for k in self], aligned_axes=(0, 1, 2))

    def to_dask(self, chunks='auto'):
        return type(self)([(k, self[k].to_dask(chunks=chunks)) for k in self], aligned_axes=(0, 1, 2))

    @property
    @u.quantity_input
    def time(self) -> u.s:
        all_times = [self[k].time for k in self]
        if not all([np.all(t == all_times[0]) for t in all_times]):
            raise ValueError('All times are not equal. Cannot return a single time.')
        return all_times[0]

    @property
    @u.quantity_input
    def wavelengths(self) -> u.angstrom:
        return u.Quantity([self[k].wavelength for k in self])

    def rechunk(self, chunks):
        return type(self)([(k, self[k].rechunk(chunks)) for k in self], aligned_axes=(0, 1, 2))

    def peak_cross_correlation_map(self, channel_a, channel_b, **kwargs):
        """
        Construct a map of the peak cross-correlation between two channels in
        each pixel of an AIA map.
        """
        # Pop the plot settings out of kwargs so they are not forwarded to
        # sunkit_image, which would raise on the unexpected keyword
        plot_settings = {
            'cmap': 'plasma',
            'vmin': 0,
            'vmax': 1,
        }
        plot_settings.update(kwargs.pop('plot_settings', {}))
        max_cc = sunkit_image.time_lag.max_cross_correlation(
            self[channel_a].data,
            self[channel_b].data,
            self.time,
            **kwargs,
        )
        max_cc_map = sunpy.map.Map(max_cc, self[channel_a].wcs.celestial)
        max_cc_map.meta['bunit'] = ''
        max_cc_map.meta['comment'] = f'{channel_a}-{channel_b} cross-correlation'
        max_cc_map.plot_settings.update(plot_settings)
        return max_cc_map

    def time_lag_map(self, channel_a, channel_b, **kwargs):
        """
        Construct a map of the time lag that maximizes the cross-correlation
        between two channels in each pixel of an AIA map.
        """
        # As above, keep the plot settings out of the sunkit_image call
        plot_settings = {
            'cmap': 'RdBu_r',
        }
        plot_settings.update(kwargs.pop('plot_settings', {}))
        max_timelag = sunkit_image.time_lag.time_lag(
            self[channel_a].data,
            self[channel_b].data,
            self.time,
            **kwargs,
        )
        time_lag_map = sunpy.map.Map(max_timelag, self[channel_a].wcs.celestial)
        time_lag_map.meta['bunit'] = 's'
        time_lag_map.meta['comment'] = f'{channel_a}-{channel_b} time lag'
        time_lag_map.plot_settings.update(plot_settings)
        return time_lag_map
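
    # Example (a sketch; assumes `collection` was built with
    # from_uninterpolated_cubes so that all channels share one time array):
    #
    # >>> tl_map = collection.time_lag_map('171', '193',
    # ...                                  lag_bounds=[-3600, 3600] * u.s)
    # >>> tl_map.peek()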