Skip to content

Instantly share code, notes, and snippets.

@xuhdev
Last active August 29, 2015 14:01
Show Gist options
  • Save xuhdev/6609672cfd7feaeb3982 to your computer and use it in GitHub Desktop.
Save xuhdev/6609672cfd7feaeb3982 to your computer and use it in GitHub Desktop.
Plot a 2D histogram and its contours
# Example input parameters for the 2D-histogram contour plotter.
# The plotting script reads this file from standard input, executes it, and
# then looks up the top-level name ``plots``: a list of per-subplot dicts.
# Optional top-level names it also reads: ``levels``, ``same_range``,
# ``main_title``, ``x_low``, ``x_high``, ``y_low``, ``y_high``.

# The first plot starts here.
plot1 = {}
# Path to the data file (a list of paths overlays several contours on one plot).
plot1['data_file'] = 'bayes1.dat'
# Which columns to plot (0-based column indices into the data file).
plot1['col1'] = 2
plot1['col2'] = 3
# Title and axis labels.
plot1['label_x'] = 'Concentration'
plot1['label_y'] = 'M200'
plot1['title'] = 'redshift = 0.2'
# Reference points, as (x, y) pairs, marked with a star.
plot1['ref_points'] = [ (5, 2e15) ]
# Number of histogram bins.
plot1['bins'] = 100

# The second plot starts here.
plot2 = {}
plot2['data_file'] = 'bayes2.dat'
plot2['col1'] = 2
plot2['col2'] = 3
plot2['label_x'] = 'Concentration'
plot2['label_y'] = 'M200'
plot2['title'] = 'redshift = 0.4'
plot2['bins'] = 100

# The subplots to draw, in order. The plotting script requires this name;
# without it the script fails with KeyError: 'plots'.
plots = [plot1, plot2]
#!/usr/bin/env python2
# plot the 2D histogram of the data with contours
from __future__ import print_function
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import matplotlib as mpl
import numpy as np
import sys
import types
from scipy.stats import gaussian_kde
class Color(object):
    """Hand out matplotlib colour codes in the repeating order red, blue, black.

    The cursor is stored on the class itself, so all callers share one cycle.
    """

    # Colour code returned by get_current_color(); the cycle starts at red.
    current_color = 'r'

    # Maps each colour in the cycle to the one that follows it.
    _SUCCESSOR = {'r': 'b', 'b': 'k', 'k': 'r'}

    @staticmethod
    def get_current_color():
        """Return the currently selected colour code."""
        return Color.current_color

    @staticmethod
    def get_next_color():
        """Step to the next colour in the cycle and return it.

        A current colour outside the cycle is left as it is.
        """
        here = Color.current_color
        Color.current_color = Color._SUCCESSOR.get(here, here)
        return Color.current_color

    @staticmethod
    def reset_color():
        """Move the cycle back to its first colour (red)."""
        Color.current_color = 'r'
# Render all text (titles, axis labels, contour labels) through LaTeX.
mpl.rc('text', usetex=True)

# Read the plot configuration: a Python script supplied on standard input.
# SECURITY NOTE: the configuration is executed with exec(), so it can run
# arbitrary code -- only feed this script trusted input.
#
# One dict serves as both globals and locals. With two separate dicts,
# exec() runs the code as if inside a class body, so functions defined by
# the configuration (e.g. a 'data_function1' lambda using a module the
# configuration imported) could not see the configuration's other names.
# The stream is read into a string first because Python 3's exec() does not
# accept file objects; this form also works on Python 2.7.
all_config = {}
exec(sys.stdin.read(), all_config)
plots_config = all_config['plots']
# Contour levels, expressed as the cumulative probability each contour encloses.
levels = all_config.get('levels', [0.68, 0.95, 0.997])

# Subplot grid. An even number of plots (more than 2) is laid out in 2 rows;
# anything else in a single row. num_plots_x is the number of rows and
# num_plots_y the number of columns, as passed to add_subplot(nrows, ncols, i).
num_plots = len(plots_config)
if num_plots % 2 == 0 and num_plots > 2:
    num_plots_x = 2
    # Floor division keeps the column count an int on Python 3 as well;
    # add_subplot rejects a float grid size.
    num_plots_y = num_plots // 2
else:
    num_plots_x = 1
    num_plots_y = num_plots

fig = plt.figure(figsize = (18, 11), dpi = 100)

# Running union of the axis limits of all subplots, applied when
# 'same_range' is set. Values from the configuration seed the union;
# otherwise it starts empty (+Inf low / -Inf high), so the first subplot's
# limits replace it.
max_lim = {
    'x_low': all_config.get('x_low', float('Inf')),
    'x_high': all_config.get('x_high', -float('Inf')),
    'y_low': all_config.get('y_low', float('Inf')),
    'y_high': all_config.get('y_high', -float('Inf')),
}
# String types accepted for a single 'data_file' path: (str, unicode) on
# Python 2, str on Python 3 (where types.StringTypes does not exist).
_string_types = getattr(types, 'StringTypes', (str,))

axes = []  # All the subplots
for v, config in enumerate(plots_config):
    ax = fig.add_subplot(num_plots_x, num_plots_y, v + 1, title = config.get('title', ''))
    axes.append(ax)
    ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
    ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
    ax.set_xlabel(config.get('label_x', ''))
    ax.set_ylabel(config.get('label_y', ''))
    col1 = config.get('col1', 0)
    col2 = config.get('col2', 1)
    # Free-form annotation anchored at the top-right of the subplot.
    ax.text(0.9, 0.9, config.get('text', ''), ha = 'right', va = 'top', transform = ax.transAxes)

    # 'data_file' is a single path or a list of paths; every file adds one
    # set of contours, in its own colour, to this subplot.
    data_files = config['data_file']
    if isinstance(data_files, _string_types):
        data_files = [data_files]

    data_function1 = config.get('data_function1', None)
    data_function2 = config.get('data_function2', None)

    lines = []  # One legend handle per data file.
    for data_file in data_files:
        color = Color.get_current_color()
        d = np.loadtxt(data_file)
        # Optional per-column transforms. A list comprehension is used
        # instead of map() because Python 3's map() is a lazy iterator,
        # which cannot be assigned into an array slice.
        if data_function1 is not None:
            d[:, col1] = [data_function1(x) for x in d[:, col1]]
        if data_function2 is not None:
            d[:, col2] = [data_function2(y) for y in d[:, col2]]
        c1 = d[:, col1]
        c2 = d[:, col2]

        # Plot the raw data points if the user asks for it.
        if config.get('plot_points', False):
            ax.scatter(c1, c2, s = 5, marker = 'x')

        # Fraction of data points per bin; h is indexed [x_bin, y_bin].
        # Dividing raw counts by their total replaces the removed
        # 'normed=True' argument (density times bin area) with the same result.
        h, xedges, yedges = np.histogram2d(c1, c2, bins=config.get('bins', 100))
        h = h / h.sum()
        xcenters = xedges[:-1] + 0.5 * (xedges[1:] - xedges[:-1])
        ycenters = yedges[:-1] + 0.5 * (yedges[1:] - yedges[:-1])

        # Evaluate a Gaussian KDE of the data at the bin centres and convert
        # it to a probability per bin (density times bin area).
        X, Y = np.meshgrid(xcenters, ycenters)  # both shaped (n_y, n_x)
        positions = np.vstack([X.ravel(), Y.ravel()])
        kernel = gaussian_kde(np.vstack([c1, c2]))
        bin_area = (yedges[1] - yedges[0]) * (xedges[1] - xedges[0])
        Z = kernel(positions) * bin_area

        # Replace each bin's probability by the summed probability of all
        # bins at least as likely, so the contour at level p bounds the most
        # probable region holding probability mass p.
        indices = np.argsort(Z, axis = None)[::-1]
        Z[indices] = np.cumsum(Z[indices])
        # Lay Z out like X (rows = y bins, columns = x bins); this is the
        # layout plt/ax.contour(xcenters, ycenters, Z) expects.
        Z = np.reshape(Z, X.shape)

        # Best-fit point: the bin with the largest KDE value. Its flat index
        # refers to the (n_y, n_x) grid, so it unravels to (y, x) indices.
        best_iy, best_ix = np.unravel_index(indices[0], X.shape)
        best_x = xcenters[best_ix]
        best_y = ycenters[best_iy]
        ax.scatter(best_x, best_y, c = color, s = 30)
        print('Best fit point is ({0},{1})'.format(best_x, best_y), file=sys.stderr)

        # Draw and label the contours.
        con = ax.contour(xcenters, ycenters, Z, levels = levels, colors = color)
        ax.clabel(con, inline=1, fontsize=10)
        # Proxy legend handle in the contour colour (ContourSet.collections
        # is deprecated and removed in recent matplotlib releases).
        lines.append(mpl.lines.Line2D([], [], color = color))

        # Grow the running union of axis limits used by 'same_range'.
        x_low, x_high = ax.get_xlim()
        y_low, y_high = ax.get_ylim()
        max_lim['x_low'] = min(max_lim['x_low'], x_low)
        max_lim['x_high'] = max(max_lim['x_high'], x_high)
        max_lim['y_low'] = min(max_lim['y_low'], y_low)
        max_lim['y_high'] = max(max_lim['y_high'], y_high)

        # Report the fraction of data points inside each contour, taken from
        # the histogram. h is [x_bin, y_bin], so compare against Z.T.
        flat_h = np.ravel(h)
        flat_Z = np.ravel(Z.T)
        num_levels = [flat_h[flat_Z < level].sum() for level in levels]
        for level, fraction in zip(levels, num_levels):
            print('Level ' + str(level) + ' contains ' + str(fraction) + ' of points', file=sys.stderr)
        print('---', file=sys.stderr)
        Color.get_next_color()
    Color.reset_color()
    print('------------', file=sys.stderr)

    # Reference points, passed through the same transforms as the data.
    # New coordinate lists are built rather than assigning into the
    # configured points: that fails for tuples (as in the example config)
    # and would silently rewrite the configuration. An empty list is skipped.
    ref_points = config.get('ref_points', None)
    if ref_points:
        ref_x = [p[0] if data_function1 is None else data_function1(p[0]) for p in ref_points]
        ref_y = [p[1] if data_function2 is None else data_function2(p[1]) for p in ref_points]
        ax.scatter(ref_x, ref_y, c = 'r', marker = '*', s = 30)

    # Legend: one label per data file, in the order the files were plotted.
    legend = config.get('legend', None)
    if legend is not None:
        ax.legend(lines, legend, loc = 'lower left')
# Optionally give every subplot the union of all subplots' axis ranges.
if all_config.get('same_range', False):
    print('Using the same axis range for all plots', file=sys.stderr)
    shared_xlim = (max_lim['x_low'], max_lim['x_high'])
    shared_ylim = (max_lim['y_low'], max_lim['y_high'])
    for subplot in axes:
        subplot.set_xlim(*shared_xlim)
        subplot.set_ylim(*shared_ylim)

fig.suptitle(all_config.get('main_title', ''), fontsize = 20)

# Save to the path given as the first command-line argument; with no
# argument, open an interactive window instead.
output_path = sys.argv[1] if len(sys.argv) > 1 else None
if output_path is None:
    plt.show()
else:
    plt.savefig(output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment