Skip to content

Instantly share code, notes, and snippets.

@gewitterblitz
Created October 27, 2019 02:09
Show Gist options
  • Save gewitterblitz/0dead913c57d1f3b3389661145c14e7f to your computer and use it in GitHub Desktop.
Save gewitterblitz/0dead913c57d1f3b3389661145c14e7f to your computer and use it in GitHub Desktop.
from glue.config import data_factory
from glue.core import Data
from glue.core.coordinates import Coordinates
from astropy import units as u
from astropy.wcs.wcsapi import BaseLowLevelWCS
import pyart
import numpy as np
class WCSCoordinateWrapper(BaseLowLevelWCS):
    """
    Minimal APE-14 low-level WCS adapter around a glue ``Coordinates`` object.

    All transforms are delegated to the wrapped object's ``pixel2world`` and
    ``world2pixel`` methods. No physical types or units are reported; every
    world axis is exposed as a dimensionless value.
    """

    def __init__(self, coords):
        # coords: object exposing ``axis_labels``, ``pixel2world`` and
        # ``world2pixel`` (e.g. a glue Coordinates subclass).
        self._coords = coords

    @property
    def pixel_n_dim(self):
        """Number of pixel dimensions, taken from the wrapped axis labels."""
        return len(self._coords.axis_labels)

    @property
    def world_n_dim(self):
        """Number of world dimensions; taken to equal the number of pixel dimensions."""
        return len(self._coords.axis_labels)

    @property
    def world_axis_physical_types(self):
        """No physical types are known, so each world axis reports ``None``."""
        return [None] * self.world_n_dim

    @property
    def world_axis_units(self):
        """Each world axis reports an empty (unknown/dimensionless) unit string."""
        return [''] * self.world_n_dim

    def pixel_to_world_values(self, *pixel_arrays):
        """Convert pixel coordinates (pixel/WCS axis order) to world values."""
        return self._coords.pixel2world(*pixel_arrays)

    def array_index_to_world_values(self, *index_arrays):
        """Convert array indices (Numpy axis order) to world values.

        The indices are reversed into pixel order before delegating.
        """
        return self._coords.pixel2world(*index_arrays[::-1])

    def world_to_pixel_values(self, *world_arrays):
        """Convert world values to fractional pixel coordinates."""
        return self._coords.world2pixel(*world_arrays)

    def world_to_array_index_values(self, *world_arrays):
        """
        Convert world values to integer array indices in Numpy axis order.

        Pixel coordinates are rounded to the nearest integer with ties going
        up (``floor(p + 0.5)``). For a 1-D WCS a single array is returned
        instead of a 1-tuple.
        """
        pixel_arrays = self.world_to_pixel_values(*world_arrays)[::-1]
        # ``np.int`` was removed in NumPy 1.24; it was an alias of the builtin
        # ``int``, which is used directly here.
        array_indices = tuple(
            np.asarray(np.floor(pixel + 0.5), dtype=int) for pixel in pixel_arrays
        )
        return array_indices[0] if self.pixel_n_dim == 1 else array_indices

    @property
    def world_axis_object_components(self):
        """Map world axis ``i`` to the ``value`` attribute of object ``'world{i}'``."""
        return [('world{0}'.format(i), 0, 'value') for i in range(self.world_n_dim)]

    @property
    def world_axis_object_classes(self):
        """Declare each ``'world{i}'`` object as a dimensionless astropy ``Quantity``."""
        return {'world{0}'.format(i): (u.Quantity, (), {'unit': u.one}) for i in range(self.world_n_dim)}
class PyartGeospatialLonLatCoordinates(Coordinates):
    """
    glue ``Coordinates`` for a Py-ART ``Grid``.

    Pixel coordinates are mapped linearly onto the grid's Cartesian ``x``,
    ``y`` and ``z`` axes, spanning each axis from its minimum to its maximum
    value. This assumes evenly spaced axis values — TODO confirm for the grids
    in use. ``x``/``y`` are then converted to longitude/latitude with Py-ART's
    azimuthal-equidistant transforms about the projection origin
    (``lon_0``/``lat_0``).

    ``axis_labels`` lists (altitude, latitude, longitude). ``pixel2world`` and
    ``world2pixel`` work in the reverse order: (x/longitude, y/latitude,
    z/altitude).
    """

    # Axis labels in the order exposed via ``axis_labels``.
    _AXIS_LABELS = ('altitude', 'latitude', 'longitude')

    def __init__(self, grd=None):
        # NOTE(review): ``grd`` is effectively required despite the ``None``
        # default; passing ``None`` fails on ``get_projparams``.
        self.grd = grd
        self.projparams = grd.get_projparams()
        self.axis_labels = list(self._AXIS_LABELS)
        # Cache (min, max, size) per Cartesian axis. This saves the transforms
        # from re-running array reductions on every call, and it is all that
        # ``__gluestate__`` needs to rebuild the transform later.
        self._extents = {
            'x': (grd.x['data'].min(), grd.x['data'].max(), grd.nx),
            'y': (grd.y['data'].min(), grd.y['data'].max(), grd.ny),
            'z': (grd.z['data'].min(), grd.z['data'].max(), grd.nz),
        }

    @staticmethod
    def _pixel_to_axis(pixel, extent):
        """Map a fractional pixel index to a Cartesian axis value."""
        lo, hi, n = extent
        if n <= 1:
            # A single-sample axis has no spacing; every pixel sits at ``lo``
            # instead of producing NaN from a division by ``n - 1 == 0``.
            return lo + np.zeros_like(np.asarray(pixel, dtype=float))
        return lo + pixel * (hi - lo) / (n - 1)

    @staticmethod
    def _axis_to_pixel(value, extent):
        """Map a Cartesian axis value to a fractional pixel index."""
        lo, hi, n = extent
        if hi == lo:
            # Degenerate axis (single sample or constant values): everything
            # maps to pixel 0 instead of dividing by a zero span.
            return np.zeros_like(np.asarray(value, dtype=float))
        return (value - lo) / (hi - lo) * (n - 1)

    def axis_label(self, axis):
        """Return the label of ``axis``, indexed like ``axis_labels``."""
        return self.axis_labels[axis]

    def pixel2world(self, *pixel):
        """
        Convert pixel coordinates ``(px, py, pz)`` to ``(lon, lat, z)``.

        The outputs are broadcast to a common shape.
        """
        px, py, pz = pixel
        x = self._pixel_to_axis(px, self._extents['x'])
        y = self._pixel_to_axis(py, self._extents['y'])
        z = self._pixel_to_axis(pz, self._extents['z'])
        lon, lat = pyart.core.transforms.cartesian_to_geographic_aeqd(
            x, y, self.projparams.get('lon_0'), self.projparams.get('lat_0'))
        lon, lat, z = np.broadcast_arrays(lon, lat, z)
        return lon, lat, z

    def world2pixel(self, *world):
        """
        Convert world coordinates ``(lon, lat, z)`` to pixel ``(px, py, pz)``.

        The outputs are broadcast to a common shape.
        """
        lon, lat, z = world
        x, y = pyart.core.transforms.geographic_to_cartesian_aeqd(
            lon, lat, self.projparams.get('lon_0'), self.projparams.get('lat_0'))
        px = self._axis_to_pixel(x, self._extents['x'])
        py = self._axis_to_pixel(y, self._extents['y'])
        pz = self._axis_to_pixel(z, self._extents['z'])
        px, py, pz = np.broadcast_arrays(px, py, pz)
        return px, py, pz

    @property
    def wcsaxes_dict(self):
        """
        Dictionary to initialize WCSAxes
        """
        return {'wcs': WCSCoordinateWrapper(self)}

    def dependent_axes(self, axis):
        """Report all three axes as dependent on one another, whatever ``axis`` is."""
        return (0, 1, 2)

    def __gluestate__(self, context):
        """
        Return a JSON-serializable description of this transform for glue sessions.

        Only the cached axis extents and the projection origin are stored; the
        Py-ART ``Grid`` object itself is not serialized.
        """
        lon_0 = self.projparams.get('lon_0')
        lat_0 = self.projparams.get('lat_0')
        return dict(
            extents={axis: [float(lo), float(hi), int(n)]
                     for axis, (lo, hi, n) in self._extents.items()},
            lon_0=None if lon_0 is None else float(lon_0),
            lat_0=None if lat_0 is None else float(lat_0),
        )

    @classmethod
    def __setgluestate__(cls, rec, context):
        """
        Rebuild coordinates from a record produced by ``__gluestate__``.

        The restored object has ``grd`` set to ``None``. Its ``projparams``
        holds only ``lon_0``/``lat_0``, which are the only keys the transforms
        read.
        """
        # Bypass __init__, which needs a live Py-ART Grid.
        coords = cls.__new__(cls)
        coords.grd = None
        coords.projparams = {'lon_0': rec['lon_0'], 'lat_0': rec['lat_0']}
        coords.axis_labels = list(cls._AXIS_LABELS)
        coords._extents = {axis: tuple(extent) for axis, extent in rec['extents'].items()}
        return coords
def is_radar(file, **kwargs):
    """
    glue identifier: decide whether ``file`` should use the radar grid reader.

    The test is purely name-based: the path must end with ``'radar_grid'``.
    Any extra keyword arguments supplied by glue are accepted and ignored.
    """
    suffix = 'radar_grid'
    matches = file.endswith(suffix)
    return matches
@data_factory('radar grid reader', identifier=is_radar, priority=10000)
def read_grid(file):
    """
    Load a Py-ART grid file into a glue ``Data`` object.

    Every field in the grid becomes a component of the returned ``Data``
    (named after the field). The coordinates are geospatial
    (longitude/latitude/altitude) via ``PyartGeospatialLonLatCoordinates``.

    Parameters
    ----------
    file : str
        Path of the grid file to read.

    Returns
    -------
    Data
        glue dataset labelled ``'grd'``.
    """
    # Read the path we were given rather than a fixed filename.
    grd = pyart.io.read_grid(file)
    gluedata = Data(label='grd')
    gluedata.coords = PyartGeospatialLonLatCoordinates(grd)
    for name, field in grd.fields.items():
        gluedata[name] = field['data']
    return gluedata
if __name__ == "__main__":
    # Smoke test: load a local 'radar_grid' file and draw WCS axes for the
    # x/y (longitude/latitude) plane at altitude pixel index 0.
    import matplotlib.pyplot as plt
    from astropy.visualization.wcsaxes import WCSAxes

    data = read_grid('radar_grid')
    fig = plt.figure()
    # A plain matplotlib subplot rejects the ``wcs``/``slices`` keywords, so
    # build a WCSAxes explicitly and attach it to the figure.
    ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], slices=('x', 'y', 0),
                 **data.coords.wcsaxes_dict)
    fig.add_axes(ax)
    fig.savefig('test.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment