Skip to content

Instantly share code, notes, and snippets.

@Sunmish
Created March 28, 2024 08:00
Show Gist options
  • Save Sunmish/9afa5fdde8b6754eb6e5cbc8660ae311 to your computer and use it in GitHub Desktop.
Save Sunmish/9afa5fdde8b6754eb6e5cbc8660ae311 to your computer and use it in GitHub Desktop.
Basic WCS axes plotting.
#! /usr/bin/env python
import os
import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from astropy.visualization import ZScaleInterval, AsymmetricPercentileInterval, simple_norm
from astropy.wcs.utils import proj_plane_pixel_scales
from astropy.visualization.wcsaxes import SphericalCircle
from astropy.table import Table
from matplotlib import pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec, SubplotSpec
from mpl_toolkits.axes_grid1.anchored_artists import (AnchoredEllipse,
AnchoredSizeBar)
from matplotlib.patches import Ellipse
from matplotlib import rc
from matplotlib.font_manager import FontProperties
from regions import Regions
import cmasher as cmr
# Matplotlib location names mapped to the integer codes accepted by the
# `loc` argument of anchored artists (used by `show_scalebar`). Codes run
# 1..10 in the order listed below.
LOCATIONS = {
    name: code
    for code, name in enumerate(
        (
            "upper right",
            "upper left",
            "lower left",
            "lower right",
            "right",
            "center left",
            "center right",
            "lower center",
            "upper center",
            "center",
        ),
        start=1,
    )
}
def lut_to_cmap(lut_file, divide_by_255=True):
    """Build a matplotlib colormap from a three-column (R, G, B) table.

    Parameters
    ----------
    lut_file : str
        Path to a text table readable by ``numpy.genfromtxt``; each row is
        one colour.
    divide_by_255 : bool, optional
        If True (default) every value is divided by 255 before use, i.e. the
        file is assumed to hold 0-255 values. If False the values are passed
        to matplotlib unchanged (they must then already be normalised).

    Returns
    -------
    matplotlib.colors.ListedColormap
    """
    table = np.genfromtxt(lut_file)
    colours = table / 255.0 if divide_by_255 else table
    return mpl.colors.ListedColormap(colours)
# Optional custom "sls" colormap read from $DROPBOX/scripts/sls.txt (values
# used as-is, i.e. already normalised). It is best-effort: if the environment
# variable is unset (KeyError) or the table is missing/unreadable (OSError,
# e.g. FileNotFoundError from numpy.genfromtxt), `sls` is None instead of the
# module failing to import.
try:
    sls = lut_to_cmap(os.environ["DROPBOX"] + "/scripts/sls.txt", False)
except (KeyError, OSError):
    sls = None
def show_bdsf_catalogue(catalogue, ax, color,
                        do_ellipses=False,
                        marker="o",
                        markersize=150):
    """Overlay sources from a PyBDSF-style catalogue on a WCSAxes.

    Parameters
    ----------
    catalogue : str
        Path to a table readable by ``astropy.table.Table.read`` with columns
        ``RA`` and ``DEC`` (degrees, drawn through the axes' ``fk5``
        transform). When ``do_ellipses`` is True it also needs ``Maj`` and
        ``Min`` (full axis lengths, degrees) and ``PA``. ``PA`` is assumed to
        follow the PyBDSF convention: the major-axis position angle in
        degrees, measured east of north.
    ax : astropy.visualization.wcsaxes.WCSAxes
        Axes to draw on.
    color : matplotlib colour
        Colour of the markers / ellipse outlines.
    do_ellipses : bool, optional
        Draw each source as an unfilled ellipse instead of a marker.
    marker : str, optional
        Marker style for the scatter overlay.
    markersize : float, optional
        Marker area (``s`` of ``scatter``).

    Returns
    -------
    WCSAxes
        The same ``ax``.
    """
    from matplotlib.transforms import Affine2D

    table = Table.read(catalogue)
    world = ax.get_transform("fk5")

    if not do_ellipses:
        # One scatter call for the whole catalogue rather than one per row.
        ax.scatter(table["RA"], table["DEC"],
                   s=markersize,
                   marker=marker,
                   color=color,
                   transform=world)
        return ax

    for ra, dec, maj, min_, pa in zip(table["RA"], table["DEC"],
                                      table["Maj"], table["Min"],
                                      table["PA"]):
        # Build the ellipse in a local (east, north) frame measured in
        # degrees of arc and centred on the source, then map it to (RA, Dec):
        # an eastward offset of d degrees of arc is d / cos(Dec) degrees of RA.
        local_to_world = (Affine2D()
                          .scale(1.0 / np.cos(np.radians(dec)), 1.0)
                          .translate(ra, dec))
        ax.add_patch(Ellipse((0.0, 0.0),
                             width=min_,
                             # Major axis points north when PA == 0 ...
                             height=maj,
                             # ... and PA runs north -> east, which is
                             # clockwise in the (east, north) frame.
                             angle=-pa,
                             edgecolor=color,
                             facecolor="none",
                             transform=local_to_world + world))
    return ax
def get_axes_from_gs(gs, fig, N):
    """Convert the first ``N`` cells of a GridSpec into axes rectangles.

    Parameters
    ----------
    gs : matplotlib.gridspec.GridSpec
        Grid whose cells are addressed by flat index ``0 .. N-1``.
    fig : matplotlib.figure.Figure
        Figure used to resolve the cell positions.
    N : int
        Number of cells to convert.

    Returns
    -------
    list of list
        ``[left, bottom, width, height]`` per cell in figure coordinates --
        the rectangle form that ``make_axes`` passes to ``plt.axes``.
    """
    return [
        list(SubplotSpec(gs, cell).get_position(figure=fig).bounds)
        for cell in range(N)
    ]
def get_last2d(array):
    """Return the trailing 2-D plane of an N-D array.

    Every leading axis beyond the last two is indexed at 0 (e.g. a FITS cube
    of shape ``(1, 1, ny, nx)`` becomes ``(ny, nx)``). Arrays with two or
    fewer dimensions are returned unchanged (the same object).

    Adapted from https://stackoverflow.com/a/27111239
    """
    n_leading = array.ndim - 2
    if n_leading <= 0:
        return array
    # Zeros for the leading axes, then keep the remaining (last two) axes.
    return array[(0,) * n_leading + (Ellipsis,)]
def auto_v(pmin, pmax, data):
    """Display limits from percentiles, padded like aplpy's old auto_v.

    The ``pmin``/``pmax`` percentiles of ``data`` are taken with
    ``AsymmetricPercentileInterval``. The lower limit is then lowered by 10%
    of that range. The upper limit is raised by 10% of the *already widened*
    range. This sequential update is kept on purpose to mirror aplpy's old
    behaviour; do not change it to a symmetric pad without checking aplpy.

    Returns
    -------
    (vmin, vmax)
    """
    low, high = AsymmetricPercentileInterval(pmin, pmax).get_limits(data)
    low = low - 0.1 * (high - low)
    high = high + 0.1 * (high - low)   # uses the updated `low`
    return low, high
def recenter(s1, wcs, centre, fov, figsize, axes):
    """Centre ``s1`` on a sky position, stretching x to fill the panel.

    The y-range spans ``fov[0]`` degrees of declination about ``centre``. The
    x-range spans ``fov[0]`` degrees of RA, multiplied by the on-page
    width/height ratio of the axes so the panel is filled. ``fov[1]`` is not
    used. See ``recenter2`` for the pixel-scale based variant.

    Parameters
    ----------
    s1 : WCSAxes
        Axes whose limits are set (modified in place).
    wcs : astropy.wcs.WCS
        Celestial WCS of the displayed image.
    centre : (float, float)
        (RA, Dec) of the new centre, in the WCS world units.
    fov : sequence of float
        Field of view; only ``fov[0]`` is read.
    figsize : (float, float)
        Figure size (width, height), e.g. from ``fig.get_size_inches()``.
    axes : sequence of 4 float
        Presumably the ``[left, bottom, width, height]`` figure-fraction
        rectangle of ``s1`` -- only ``axes[2] / axes[3]`` is used.

    Returns
    -------
    WCSAxes
        ``s1``.
    """
    xpix, ypix = wcs.all_world2pix(centre[0], centre[1], 0)
    # Physical width/height ratio of the panel on the page.
    rxy = (figsize[0] / figsize[1]) * axes[2] / axes[3]

    ypix1 = wcs.wcs_world2pix(centre[0], centre[1] + 0.5 * fov[0], 0)[1]
    ypix2 = wcs.wcs_world2pix(centre[0], centre[1] - 0.5 * fov[0], 0)[1]
    y_range = abs(ypix1 - ypix2)

    xpix1 = wcs.wcs_world2pix(centre[0] - 0.5 * fov[0], centre[1], 0)[0]
    xpix2 = wcs.wcs_world2pix(centre[0] + 0.5 * fov[0], centre[1], 0)[0]
    x_range = abs(xpix1 - xpix2)

    s1.set_ylim([ypix - 0.5 * y_range, ypix + 0.5 * y_range])
    s1.set_xlim([xpix - 0.5 * x_range * rxy, xpix + 0.5 * x_range * rxy])
    return s1
def recenter2(s1, wcs, centre, fov):
    """Centre ``s1`` on ``centre`` showing a ``fov[0]`` x ``fov[1]`` field.

    The field size is converted to pixels with the projected-plane pixel
    scales of ``wcs``; the axes limits are modified in place and ``s1`` is
    returned.

    Adapted from
    https://aplpy.readthedocs.io/en/stable/_modules/aplpy/core.html#FITSFigure.recenter
    """
    x_centre, y_centre = wcs.all_world2pix(centre[0], centre[1], 0)
    scales = proj_plane_pixel_scales(wcs)
    half_width = fov[0] / scales[0] * 0.5
    half_height = fov[1] / scales[1] * 0.5
    s1.set_xlim([x_centre - half_width, x_centre + half_width])
    s1.set_ylim([y_centre - half_height, y_centre + half_height])
    return s1
def make_axes(header, fig,
              ax=None,
              gs_ax=None,
              data=None,
              vsetting=None,
              psetting=None,
              scale=1000.,
              cmap="gray",
              centre=None,
              fov=None,
              do_axis_labels=True,
              fontlabels=14.,
              fontticks=14.,
              do_colorbar=True,
              colorbar_label=None,
              colorbar_thickness=0.0075,
              colorbar_pad=0.0005,
              colorbar_label_pad=0.,
              colorbar_orientation="vertical",
              colorbar_label_on_top=False,
              colorbar_label_on_top_alignment="right",
              do_beam=True,
              rotate_labels=False,
              norm=None,
              sans=True,
              tick_direction="in",
              tick_colour="black",
              tick_size=8,
              aspect="auto",
              ):
    """Create a WCSAxes panel for an image, with optional colorbar and beam.

    Parameters
    ----------
    header : astropy.io.fits.Header
        Supplies the WCS (only its celestial part is used). When ``do_beam``
        is True it must also contain ``BMAJ``, ``BMIN`` and ``BPA``.
    fig : matplotlib.figure.Figure
        Figure that the image axes and colorbar axes are added to.
    ax : sequence of 4 float, optional
        ``[left, bottom, width, height]`` rectangle in figure coordinates.
    gs_ax : SubplotSpec, optional
        Grid cell to place the axes in. One of ``ax``/``gs_ax`` is required;
        ``ax`` wins if both are given.
    data : array, optional
        Image to show (must match the celestial WCS); it is multiplied by
        ``scale`` before display.
    vsetting : (vmin, vmax), optional
        Explicit display limits in the *scaled* units.
    psetting : (pmin, pmax), optional
        Percentiles for ``auto_v``, used when ``vsetting`` is None. If both
        are None, a z-scale interval is used. Automatic limits are computed
        on the raw data and then multiplied by ``scale``.
    scale : float
        Multiplier applied to ``data`` for display.
    cmap : str or Colormap
        Colormap; the string ``"sls"`` selects the module-level ``sls`` map.
    centre, fov : optional
        When both are given, the view is zoomed with ``recenter2``.
    do_axis_labels, fontlabels, fontticks, rotate_labels, tick_direction,
    tick_colour, tick_size :
        Coordinate tick / label styling.
    do_colorbar, colorbar_* :
        Colorbar placement (thickness and pad are figure fractions, corrected
        for the figure aspect) and labelling.
    do_beam : bool
        Draw the restoring beam from the header in the lower-left corner.
    norm : matplotlib.colors.Normalize, optional
        Overrides ``vsetting``/``psetting``.
    sans : bool
        Switch mathtext to DejaVu Sans (updates global ``plt.rcParams``).
    aspect : str or float
        Passed to ``set_aspect`` on the image axes.

    Returns
    -------
    (WCSAxes, Figure, Colorbar or None)

    Raises
    ------
    RuntimeError
        If neither ``ax`` nor ``gs_ax`` is given.
    KeyError
        If ``do_beam`` is True and a beam keyword is missing from ``header``.
    """
    if isinstance(cmap, str) and cmap == "sls":
        cmap = sls

    if sans:
        params = {'text.usetex': False, 'mathtext.fontset': "dejavusans"}
        plt.rcParams.update(params)

    wcs = WCS(header).celestial
    if ax is not None:
        # Add to `fig` explicitly (plt.axes would use whatever figure is
        # current, which need not be `fig`).
        s1 = fig.add_axes(ax, projection=wcs)
    elif gs_ax is not None:
        s1 = fig.add_subplot(gs_ax, projection=wcs)
    else:
        raise RuntimeError("Either `ax` or `gs_ax` should be specified.")

    figsize = fig.get_size_inches()

    if data is not None:
        if norm is None:
            if vsetting is None:
                if psetting is None:
                    vmin, vmax = ZScaleInterval().get_limits(data)
                else:
                    vmin, vmax = auto_v(psetting[0], psetting[1], data)
                # Limits come from the raw data; the image is shown scaled.
                vmin *= scale
                vmax *= scale
            else:
                vmin, vmax = vsetting
            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        s1.imshow(data * scale,
                  norm=norm,
                  cmap=cmap,
                  origin="lower",
                  aspect="auto"
                  )

    if (centre is not None) and (fov is not None):
        s1 = recenter2(s1, wcs, centre, fov)

    if do_axis_labels:
        ra = s1.coords[0]
        dec = s1.coords[1]
        for coord in (ra, dec):
            coord.set_ticks(
                size=tick_size,
                direction=tick_direction,
                color=tick_colour,
            )
            coord.display_minor_ticks(True)
            coord.set_minor_frequency(4)
            coord.set_ticklabel(
                color="black",
                size=fontticks,
                pad=5.,
                exclude_overlapping=True
            )
            coord.tick_params(
                which="minor",
                length=tick_size * 0.5
            )
            coord.set_auto_axislabel(False)
        ra.set_axislabel(r"$\alpha_{\mathrm{\mathsf{J2000}}}$",
                         fontsize=fontlabels
                         )
        dec.set_axislabel(r"$\delta_{\mathrm{\mathsf{J2000}}}$",
                          fontsize=fontlabels
                          )
        if rotate_labels:
            dec.set_ticklabel(rotation="vertical")

    s1.set_aspect(aspect)

    if do_colorbar:
        # [left, bottom, width, height] of the image axes in figure coords.
        pos = s1.get_position().bounds
        fig_aspect = figsize[0] / figsize[1]
        if colorbar_orientation == "horizontal":
            # Thin strip directly above the image axes.
            cbax = [pos[0],
                    pos[1] + pos[3] + colorbar_pad * fig_aspect,
                    pos[2],
                    colorbar_thickness * fig_aspect]
            label_pad = 7
        else:
            # Thin strip directly to the right of the image axes.
            cbax = [pos[2] + pos[0] + colorbar_pad / fig_aspect,
                    pos[1],
                    colorbar_thickness / fig_aspect,
                    pos[3]]
            label_pad = colorbar_label_pad
        colorbar_axis = fig.add_axes(cbax)
        colorbar = mpl.colorbar.ColorbarBase(colorbar_axis,
                                             cmap=plt.get_cmap(cmap),
                                             norm=norm,
                                             orientation=colorbar_orientation
                                             )
        if colorbar_label is None:
            colorbar_label = r"Stokes $I$ / mJy PSF$^{-1}$"
        if colorbar_label_on_top:
            colorbar_axis.set_title(colorbar_label,
                                    fontsize=fontlabels,
                                    ha=colorbar_label_on_top_alignment)
        else:
            colorbar.set_label(colorbar_label,
                               fontsize=fontlabels,
                               labelpad=label_pad)
        if colorbar_orientation == "horizontal":
            colorbar.ax.xaxis.set_ticks_position("top")
            colorbar.ax.xaxis.set_label_position("top")
        else:
            colorbar.ax.yaxis.set_ticks_position("right")
            colorbar.ax.yaxis.set_label_position("right")
        colorbar.ax.tick_params(which="major",
                                labelsize=fontticks,
                                length=4.,
                                direction="out",
                                labelcolor="black"
                                )
    else:
        colorbar = None

    if do_beam:
        pix_scale = proj_plane_pixel_scales(wcs)
        sx, sy = pix_scale[0], pix_scale[1]
        bmaj = header["BMAJ"]
        bmin = header["BMIN"]
        bpa = header["BPA"]
        # Geometric-mean pixel size converts beam axes to pixels.
        xypixscale = np.sqrt(sx * sy)
        bmaj_pix = bmaj / xypixscale
        bmin_pix = bmin / xypixscale
        beam = AnchoredEllipse(s1.transData,
                               width=bmin_pix,
                               height=bmaj_pix,
                               angle=bpa,
                               loc="lower left",
                               frameon=True,
                               pad=0.2,
                               borderpad=1.
                               )
        beam.ellipse.set_edgecolor("black")
        beam.ellipse.set_facecolor("black")
        s1.add_artist(beam)

    return s1, fig, colorbar
def show_contours(ax, contour_image, color, levels,
                  linewidth=1.2,
                  linestyle="-",
                  zorder=None):
    """Draw contours of a FITS image on an existing WCSAxes.

    The contour image is read from the primary HDU with length-1 axes
    squeezed out. Its pixels are drawn through
    ``ax.get_transform(<its celestial WCS>)``, so it need not share the pixel
    grid of the image already on ``ax``.

    Parameters
    ----------
    ax : WCSAxes
    contour_image : str
        Path to the FITS file to contour.
    color, levels, linewidth, linestyle, zorder :
        Passed to ``ax.contour`` (as ``colors``, ``levels``, ``linewidths``,
        ``linestyles``, ``zorder``).

    Returns
    -------
    WCSAxes
        The same ``ax``.
    """
    with fits.open(contour_image) as hdul:
        primary = hdul[0]
        image = np.squeeze(primary.data)
        contour_wcs = WCS(primary.header).celestial
        ax.contour(image,
                   levels=levels,
                   colors=color,
                   linewidths=linewidth,
                   linestyles=linestyle,
                   transform=ax.get_transform(contour_wcs),
                   zorder=zorder)
    return ax
def show_beam(ax, wcs, bmaj, bmin, bpa,
              loc="lower left",
              color="black"):
    """Draw a filled beam ellipse anchored in a corner of ``ax``.

    Parameters
    ----------
    ax : WCSAxes
    wcs : astropy.wcs.WCS
        WCS of the displayed image; its geometric-mean pixel scale converts
        the beam axes to pixels.
    bmaj, bmin : float
        Beam major/minor axes in the WCS world units (as FITS BMAJ/BMIN).
    bpa : float
        Beam angle, passed unchanged as the ellipse ``angle``.
    loc : str
        Anchor location for ``AnchoredEllipse``.
    color : matplotlib colour
        Used for both the ellipse edge and fill.

    Returns
    -------
    WCSAxes
        The same ``ax``.
    """
    scales = proj_plane_pixel_scales(wcs)
    mean_pixel = np.sqrt(scales[0] * scales[1])
    beam = AnchoredEllipse(ax.transData,
                           width=bmin / mean_pixel,
                           height=bmaj / mean_pixel,
                           angle=bpa,
                           loc=loc,
                           frameon=True,
                           pad=0.5,
                           borderpad=1.,
                           snap=False)
    for paint in (beam.ellipse.set_edgecolor, beam.ellipse.set_facecolor):
        paint(color)
    ax.add_artist(beam)
    return ax
def show_scalebar(ax, wcs, scale, label, fontsize,
                  loc="upper right",
                  color="white",
                  frame=False,
                  borderpad=0.4,
                  pad=0.5,
                  size_vertical=0.,
                  **kwargs):
    """Show a linear scale bar on an axis.

    Parameters
    ----------
    ax : WCSAxes
        Axes to draw on (the bar is sized in its data, i.e. pixel, units).
    wcs : astropy.wcs.WCS
        WCS of the displayed image; its geometric-mean pixel scale converts
        ``scale`` into pixels.
    scale : float
        Bar length in WCS world units (degrees for a standard celestial WCS).
    label : str
        Text shown with the bar.
    fontsize : float
        Label font size.
    loc : str or int
        A key of ``LOCATIONS`` or an integer matplotlib location code.
    color, frame, borderpad, pad :
        Passed to ``AnchoredSizeBar`` (``frame`` as ``frameon``).
    size_vertical : float
        Bar thickness in data (pixel) units; 0 draws a line.
    **kwargs :
        Extra ``AnchoredSizeBar`` keyword arguments.

    Returns
    -------
    (WCSAxes, AnchoredSizeBar)
    """
    pix_scale = proj_plane_pixel_scales(wcs)
    sx, sy = pix_scale[0], pix_scale[1]
    xypixscale = np.sqrt(sx * sy)
    length = scale / xypixscale
    loc_code = LOCATIONS[loc] if isinstance(loc, str) else loc
    scalebar = AnchoredSizeBar(
        ax.transData, length, label, loc_code,
        pad=pad,
        borderpad=borderpad,
        sep=5,
        frameon=frame,
        color=color,
        # Previously accepted but never forwarded, so it had no effect.
        size_vertical=size_vertical,
        fontproperties=FontProperties(size=fontsize),
        **kwargs
    )
    ax.add_artist(scalebar)
    return ax, scalebar
def show_region(ax, wcs, region_file, color=None):
    """Overlay every region from a DS9 region file on ``ax``.

    Each region is converted to pixel coordinates with ``wcs`` and plotted
    with ``zorder=100``. When ``color`` is given it is also passed to each
    region's ``plot`` call.

    Returns
    -------
    The same ``ax``.
    """
    plot_kwargs = {"zorder": 100}
    if color is not None:
        plot_kwargs["color"] = color
    for region in Regions.read(region_file, format="ds9"):
        region.to_pixel(wcs).plot(ax=ax, **plot_kwargs)
    return ax
def show_png(ax, png,
             vertical_flip=False,
             horizontal_flip=False,
             interpolation="nearest"):
    """Display a raster image (e.g. a PNG) on ``ax`` with ``origin="lower"``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    png : str or file-like
        Image readable by PIL's ``Image.open``.
    vertical_flip, horizontal_flip : bool
        Flip the image top-bottom / left-right before display.
    interpolation : str
        Passed to ``imshow``.

    Returns
    -------
    The same ``ax``.

    Raises
    ------
    ImportError
        If neither Pillow nor the legacy top-level PIL ``Image`` module is
        available.
    """
    try:
        from PIL import Image
    except ImportError:
        try:
            import Image  # legacy (pre-Pillow) PIL installation
        except ImportError:
            raise ImportError("The Python Imaging Library (PIL) is required to read in RGB images")

    # Disable PIL's decompression-bomb size check so very large images load.
    # NOTE(review): this removes a safety limit -- do not use on untrusted
    # images.
    Image.MAX_IMAGE_PIXELS = None

    # Load the image on every import path (previously only one path did), and
    # close the file once imshow has copied the pixels.
    with Image.open(png) as source:
        image = source
        if vertical_flip:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)
        if horizontal_flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        ax.imshow(image,
                  interpolation=interpolation,
                  origin="lower",
                  )
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment