Skip to content

Instantly share code, notes, and snippets.

@markusritschel
Last active November 15, 2023 23:58
Show Gist options
  • Save markusritschel/e54ec6cba5f5fd7b71e2135892660cbb to your computer and use it in GitHub Desktop.
Save markusritschel/e54ec6cba5f5fd7b71e2135892660cbb to your computer and use it in GitHub Desktop.
Accessor for cartopy.GeoAxes objects to simplify adding features to (stereographic) map plots
import functools
from abc import ABC, abstractmethod

import cartopy
import cartopy.crs
import cartopy.feature
import cartopy.mpl.geoaxes
import numpy as np
from matplotlib import pyplot as plt, path as mpath
def register_geoaxes_accessor(accessor_name):
    """
    Register an accessor for a cartopy.GeoAxes object.

    The decorated class is exposed as the property ``accessor_name`` on every
    :class:`cartopy.mpl.geoaxes.GeoAxes`. The class is instantiated with the
    axes as its only argument the first time the property is accessed on a
    given axes. The instance is then cached on that axes, so work done in
    ``__init__`` is not repeated and state stored on the accessor persists
    between calls.

    Parameters
    ----------
    accessor_name : str
        Attribute name under which the accessor is exposed on GeoAxes.

    Returns
    -------
    callable
        A class decorator that registers the class and returns it unchanged.

    Example
    -------
    >>> @register_geoaxes_accessor("my_accessor")
    ... class MyCustomAccessor:
    ...     def __init__(self, geo_axes):
    ...         self.geo_axes = geo_axes
    ...     def some_method(self):
    ...         pass
    >>> ax = plt.subplot(projection=cartopy.crs.NorthPolarStereo())
    >>> ax.my_accessor.some_method()
    """
    def actual_decorator(cls):
        @functools.wraps(cls)
        def accessor(geo_axes):
            # Without a cache, every `ax.<name>` lookup would build a fresh
            # accessor and rerun its __init__ (e.g. resetting the extent).
            cache = vars(geo_axes).setdefault('_geoaxes_accessor_cache', {})
            if accessor_name not in cache:
                cache[accessor_name] = cls(geo_axes)
            return cache[accessor_name]

        setattr(cartopy.mpl.geoaxes.GeoAxes, accessor_name, property(accessor))
        return cls

    return actual_decorator
class GeoAxesAccessor(ABC):
    """Base class for accessors registered on :class:`cartopy.mpl.geoaxes.GeoAxes`.

    Provides thin wrappers around common GeoAxes operations. Subclasses must
    implement :meth:`add_gridlines` and :meth:`add_features`.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The axes this accessor operates on; stored as ``self.geo_axes``.
    """

    def __init__(self, ax):
        self.geo_axes = ax

    def add_ocean(self, **kwargs):
        """Add the cartopy OCEAN feature; ``zorder`` defaults to 0.

        Extra keyword arguments are forwarded to ``GeoAxes.add_feature``.
        """
        kwargs.setdefault('zorder', 0)
        self.geo_axes.add_feature(cartopy.feature.OCEAN, **kwargs)

    def add_land(self, **kwargs):
        """Add the cartopy LAND feature; ``zorder`` defaults to 2.

        Extra keyword arguments are forwarded to ``GeoAxes.add_feature``.
        """
        kwargs.setdefault('zorder', 2)
        self.geo_axes.add_feature(cartopy.feature.LAND, **kwargs)

    def add_coastlines(self, **kwargs):
        """Draw coastlines; keyword arguments are forwarded to ``GeoAxes.coastlines``."""
        self.geo_axes.coastlines(**kwargs)

    def set_extent(self, extent, crs=cartopy.crs.PlateCarree()):
        """Set the map extent ``[x0, x1, y0, y1]`` expressed in ``crs`` (default: PlateCarree)."""
        self.geo_axes.set_extent(extent, crs)

    @abstractmethod
    def add_gridlines(self, **kwargs):
        """Draw gridlines; accepted keyword arguments are defined by the subclass."""

    @abstractmethod
    def add_features(self, **kwargs):
        """Add the subclass's standard set of map features."""
@register_geoaxes_accessor("polar")
class StereographicAxisAccessor(GeoAxesAccessor):
    """An accessor to handle features and finishing of stereographic plots produced with `cartopy`.

    Can handle both :class:`ccrs.NorthPolarStereo` and :class:`ccrs.SouthPolarStereo`
    projections, including subclasses of them. It is registered under the name
    ``polar``, i.e. it is used as ``ax.polar.add_features()``.
    """

    # (projection class, pole) pairs, checked with isinstance().
    _SUPPORTED_PROJECTIONS = (
        (cartopy.crs.SouthPolarStereo, 'south'),
        (cartopy.crs.NorthPolarStereo, 'north'),
    )

    def __init__(self, ax):
        super().__init__(ax)
        # Public GeoAxes.projection instead of matplotlib's private _projection_init.
        projection = ax.projection
        self._type = type(projection)
        self._pole = self._detect_pole(projection)
        self._lat_limits = {'south': [-90, -50],
                            'north': [50, 90]}[self._pole]
        self.set_extent([-180, 180, *self._lat_limits])
        # Between this latitude and the pole, meridians are drawn at double spacing.
        self._lat_breakpoint = {'south': -80,
                                'north': 80}[self._pole]
        self._lon_grid_spacing = 30
        self._draw_labels = False  # or should this rather be an attribute of self.geo_axes._draw_labels ?

    @classmethod
    def _detect_pole(cls, projection):
        """Return 'north' or 'south' for a supported polar stereographic projection.

        Raises
        ------
        KeyError
            If the projection is not a (subclass of a) supported projection.
            KeyError is kept because the former dict lookup raised it.
        """
        for crs_class, pole in cls._SUPPORTED_PROJECTIONS:
            if isinstance(projection, crs_class):
                return pole
        raise KeyError(
            f"Unsupported projection {type(projection).__name__}; "
            "expected NorthPolarStereo or SouthPolarStereo."
        )

    def info(self):
        """Return a dict summarising projection type, pole, latitude limits and meridian spacing."""
        info_dict = dict(
            projection=self._type,
            pole=self._pole,
            lat_limits=self._lat_limits,
            lon_grid_spacing=self._lon_grid_spacing
        )
        return info_dict

    def add_features(self, **kwargs):
        """Perform the following steps:

        - add ocean
        - add land
        - add coastlines
        - add ruler
        - make the boundary circular
        - add gridlines

        Keyword arguments ``ocean_kwargs``, ``land_kwargs``, ``coastlines_kwargs``,
        ``ruler_kwargs`` and ``gridlines_kwargs`` (dicts) are forwarded to the
        corresponding step. Any other keyword arguments are ignored.
        ``ruler_kwargs['segment_length']``, if given, also sets the meridian
        spacing used for the gridlines.
        """
        coastlines_kwargs = kwargs.pop('coastlines_kwargs', {})
        ruler_kwargs = kwargs.pop('ruler_kwargs', {})
        ocean_kwargs = kwargs.pop('ocean_kwargs', {})
        land_kwargs = kwargs.pop('land_kwargs', {})
        gridlines_kwargs = kwargs.pop('gridlines_kwargs', {})
        # Keep the current spacing unless the caller overrides it explicitly.
        self._lon_grid_spacing = ruler_kwargs.get('segment_length', self._lon_grid_spacing)
        self.add_ocean(**ocean_kwargs)
        self.add_land(**land_kwargs)
        self.add_coastlines(**coastlines_kwargs)
        self.add_ruler(**ruler_kwargs)
        self.make_circular()
        self.add_gridlines(**gridlines_kwargs)

    def add_ruler(self, **kwargs):
        """Draw the circular ruler; ``segment_length`` defaults to the current meridian spacing."""
        kwargs.setdefault('segment_length', self._lon_grid_spacing)
        add_circular_ruler(self.geo_axes, **kwargs)

    def make_circular(self):
        """Clip the map to a circle inscribed in the axes."""
        set_circular_boundary(self.geo_axes)

    def add_gridlines(self, **kwargs):
        """Draw latitude circles and meridians.

        Meridians are drawn every ``self._lon_grid_spacing`` degrees. Between the
        breakpoint latitude and the pole, they are drawn at twice that spacing to
        reduce clutter. Latitude circles are drawn every ``lat_grid_spacing``
        degrees (optional keyword, default 10). All other keyword arguments are
        forwarded to ``GeoAxes.gridlines``. Defaults: zorder=1, linestyle='-',
        linewidth=0.5, color='gray', alpha=0.7.
        """
        lat_grid_spacing = kwargs.pop('lat_grid_spacing', 10)
        kwargs.setdefault('zorder', 1)
        kwargs.setdefault('linestyle', '-')
        kwargs.setdefault('linewidth', 0.5)
        kwargs.setdefault('color', 'gray')
        kwargs.setdefault('alpha', 0.7)
        lat0, lat1 = self._lat_limits
        # Spacing factors for the lower and the upper latitude band; the band
        # touching the pole gets factor 2.
        fac_lower, fac_upper = {'south': [2, 1],
                                'north': [1, 2]}[self._pole]
        ygrid_locs = np.arange(lat0, lat1 + 1, lat_grid_spacing)
        bands = (
            (fac_lower, [lat0, self._lat_breakpoint]),
            (fac_upper, [self._lat_breakpoint, lat1]),
        )
        for fac, ylim in bands:
            gl = self.geo_axes.gridlines(xlocs=np.arange(-180, 180, fac * self._lon_grid_spacing),
                                         ylim=ylim,
                                         ylocs=ygrid_locs,
                                         draw_labels=self._draw_labels,
                                         **kwargs
                                         )
            # Style the labels of both gridliners consistently.
            gl.ylabel_style = {'color': '.2', 'weight': 'bold', 'fontsize': 'smaller'}
def set_circular_boundary(ax):
    """Replace the map boundary with a circle inscribed in the axes.

    The circle is defined in axes coordinates (``ax.transAxes``), so it stays
    circular however the map is panned or zoomed.

    Parameters
    ----------
    ax : GeoAxes
        Axes whose boundary is replaced via ``ax.set_boundary``.
    """
    n_vertices = 100
    angles = np.linspace(0, 2*np.pi, n_vertices)
    # Unit circle traced clockwise starting at the top: x = sin, y = cos.
    unit_circle = np.column_stack((np.sin(angles), np.cos(angles)))
    radius = 0.5
    center = [0.5, 0.5]
    boundary = mpath.Path(unit_circle*radius + center)
    ax.set_boundary(boundary, transform=ax.transAxes)
class _SegmentLengthError(ValueError, Warning):
    """Raised when a ruler segment length cannot tile the circle evenly.

    Also subclasses ``Warning`` because this condition used to be raised as a
    bare ``Warning``, so existing ``except Warning`` handlers keep working.
    """


def add_circular_ruler(ax, segment_length=30, offset=0, primary_color='k', secondary_color='w', width=None):
    """Add a ruler around a polar stereographic plot.

    A full circle in ``primary_color`` is drawn first. Every other segment is
    then overdrawn in ``secondary_color``, giving an alternating pattern. The
    circle has radius 0.5 in axes coordinates, centred in the axes.

    Parameters
    ----------
    ax : GeoAxes
        The GeoAxes object to which the ruler should be added.
    segment_length : int
        The length of each segment in degrees. It is also used as the number of
        points per segment, so it must be an integer.
    offset : int
        Rotation of the segment pattern in degrees (counter-clockwise).
    primary_color : str
        The color of the background circle.
    secondary_color : str
        The color of the top ruler segments.
    width : float, optional
        The thickness of the ruler. The background circle is drawn with
        ``2*width + 1`` and the segments with ``2*width``.
        Defaults to ``ax.bbox.width / 80``.

    Raises
    ------
    ValueError
        If ``segment_length`` is not positive or 360 is not divisible into an
        even number of segments of that length. The raised exception is also a
        ``Warning`` subclass.
    """
    # Validate before touching the axes; also avoids division by zero.
    if segment_length <= 0 or (360 / segment_length) % 2 != 0:
        raise _SegmentLengthError(
            "`segment_length` must be positive and fit 2n times into 360 so that the "
            "segments of the ruler can be equally distributed on a circle."
        )

    def plot_circle(degrees, radius=0.5, **kwargs):
        """Plot points at the given angles (degrees, counter-clockwise from +x) on a centred circle."""
        # `ax` comes from the enclosing scope. The old kwargs.pop("ax", plt.gca())
        # called plt.gca() eagerly, which can create a stray figure/axes.
        arc_angles = np.deg2rad(degrees)
        arc_xs = radius * np.cos(arc_angles) + 0.5
        arc_ys = radius * np.sin(arc_angles) + 0.5
        ax.plot(arc_xs, arc_ys, transform=ax.transAxes, solid_capstyle="butt", **kwargs)

    width = ax.bbox.width / 80 if width is None else width

    # Background circle (default: black, slightly broader).
    arc_array = np.linspace(0, 360, 360, endpoint=True)
    plot_circle(arc_array, color=primary_color, lw=width * 2 + 1, zorder=999)

    # Every second segment on top (default: white), as (start, end) rows.
    bnds_array = (
        np.arange(0, 360, segment_length).reshape((-1, 2))
        + 90  # to start at the top instead of at the right
        + offset
    )
    # A NaN between segments breaks the line so segments are not connected.
    arc_array = np.concatenate(
        [np.append(np.linspace(start, end, segment_length, endpoint=True), np.nan)
         for start, end in bnds_array]
    )
    plot_circle(arc_array, color=secondary_color, lw=width * 2, zorder=1000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment