Last active
August 29, 2015 14:01
-
-
Save xuhdev/6609672cfd7feaeb3982 to your computer and use it in GitHub Desktop.
Plot 2D histogram and its contour
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is an example of the input parameters.  The plotting script executes
# this file as Python source and then reads the resulting top-level names.

# The first plot starts here.
plot1 = {
    # Path to the data file
    'data_file': 'bayes1.dat',
    # Which columns to plot?
    'col1': 2,
    'col2': 3,
    # Title and labels
    'label_x': 'Concentration',
    'label_y': 'M200',
    'title': 'redshift = 0.2',
    # Reference points, each an [x, y] pair
    'ref_points': [[5, 2e15]],
    # Number of bins
    'bins': 100,
}

# The second plot starts here.
plot2 = {
    'data_file': 'bayes2.dat',
    'col1': 2,
    'col2': 3,
    'label_x': 'Concentration',
    'label_y': 'M200',
    'title': 'redshift = 0.4',
    'bins': 100,
}

# The plotting script looks up the name `plots` to find the list of plots to
# draw; without it the script fails with a KeyError.
plots = [plot1, plot2]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python2 | |
| # plot the 2D histogram of the data with contours | |
| from __future__ import print_function | |
| import matplotlib.pyplot as plt | |
| from matplotlib.ticker import ScalarFormatter | |
| import matplotlib as mpl | |
| import numpy as np | |
| import sys | |
| import types | |
| from scipy.stats import gaussian_kde | |
class Color(object):
    """Global cursor over the colours handed out to successive contours.

    All state lives on the class itself, so the methods are static and the
    class is never instantiated.
    """

    # Colours in the order they are handed out; the cycle wraps around after
    # the last entry.  Extend this tuple to support more overlaid contours.
    palette = ('r', 'b', 'k')
    current_color = palette[0]

    @staticmethod
    def get_current_color():
        """Return the colour currently selected, without advancing."""
        return Color.current_color

    @staticmethod
    def get_next_color():
        """Advance to the next palette colour (wrapping around) and return it.

        If the current colour is not part of the palette it is left
        unchanged and returned as is.
        """
        palette = Color.palette
        if Color.current_color in palette:
            position = palette.index(Color.current_color)
            Color.current_color = palette[(position + 1) % len(palette)]
        return Color.current_color

    @staticmethod
    def reset_color():
        """Rewind the cursor to the first colour of the palette."""
        Color.current_color = Color.palette[0]
try:
    string_types = basestring  # Python 2: covers both str and unicode
except NameError:
    string_types = str  # Python 3


def subplot_grid(num_plots):
    """Return the two grid dimensions passed to ``fig.add_subplot``.

    An even number of plots greater than two is laid out with 2 as the first
    dimension and ``num_plots // 2`` as the second; anything else uses
    ``(1, num_plots)``.  Both values are always integers.
    """
    if num_plots % 2 == 0 and num_plots > 2:
        return 2, num_plots // 2
    return 1, num_plots


def as_file_list(data_files):
    """Wrap a single file name in a list; return any other value unchanged."""
    if isinstance(data_files, string_types):
        return [data_files]
    return data_files


def cumulative_probability(mass):
    """Accumulate probability mass from the densest cell downwards.

    Returns ``(cumulative, peak)``.  ``cumulative`` has the shape of ``mass``;
    each cell holds the sum of its own mass and the mass of every cell that
    sorts above it.  ``peak`` is the flat (C-order) index of the cell with the
    largest mass.
    """
    flat = np.ravel(mass)
    order = np.argsort(flat)[::-1]
    cumulative = np.empty(flat.shape, dtype=float)
    cumulative[order] = np.cumsum(flat[order])
    return cumulative.reshape(np.shape(mass)), order[0]


def mass_within_levels(bin_mass, cumulative, levels):
    """Sum ``bin_mass`` over the cells whose ``cumulative`` value is below each level.

    ``bin_mass`` and ``cumulative`` must have the same layout.  Returns one
    total per entry of ``levels``, in the same order.
    """
    flat_mass = np.ravel(bin_mass)
    flat_cumulative = np.ravel(cumulative)
    return np.array([flat_mass[flat_cumulative < level].sum() for level in levels])


def transform_ref_points(ref_points, function1=None, function2=None):
    """Return the reference points as ``(x, y)`` tuples with the data functions applied.

    ``function1`` is applied to the first coordinate and ``function2`` to the
    second; either may be ``None``.  The input is never modified, so points
    given as tuples are accepted.  ``None`` is passed through as ``None``.
    """
    if ref_points is None:
        return None
    transformed = []
    for point in ref_points:
        x, y = point[0], point[1]
        if function1 is not None:
            x = function1(x)
        if function2 is not None:
            y = function2(y)
        transformed.append((x, y))
    return transformed


def main():
    """Read the configuration from stdin and draw every configured plot.

    The figure is saved to ``sys.argv[1]`` when that argument is given and
    shown interactively otherwise.  Diagnostics go to stderr.
    """
    # Use LaTeX rendering
    mpl.rc('text', usetex=True)

    # SECURITY: the configuration is Python source and is executed verbatim.
    # Only feed this script configuration files you trust.
    all_config = {}
    exec(sys.stdin.read(), {}, all_config)

    plots_config = all_config['plots']
    levels = all_config.get('levels', [0.68, 0.95, 0.997])
    grid_first, grid_second = subplot_grid(len(plots_config))

    fig = plt.figure(figsize=(18, 11), dpi=100)

    # The widest axis limits seen over all figures (used by 'same_range').
    max_lim = {
        'x_low': all_config.get('x_low', float('Inf')),
        'x_high': all_config.get('x_high', -float('Inf')),
        'y_low': all_config.get('y_low', float('Inf')),
        'y_high': all_config.get('y_high', -float('Inf')),
    }

    axes = []  # All the subplots

    for index, config in enumerate(plots_config):
        ax = fig.add_subplot(grid_first, grid_second, index + 1,
                             title=config.get('title', ''))
        axes.append(ax)
        ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
        ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
        ax.set_xlabel(config.get('label_x', ''))
        ax.set_ylabel(config.get('label_y', ''))
        col1 = config.get('col1', 0)
        col2 = config.get('col2', 1)

        # plot text
        ax.text(0.9, 0.9, config.get('text', ''), ha='right', va='top',
                transform=ax.transAxes)

        # 'data_file' may be a single name or a list of names when several
        # contours are drawn on the same plot.
        data_files = as_file_list(config['data_file'])
        data_function1 = config.get('data_function1', None)
        data_function2 = config.get('data_function2', None)

        lines = []
        for data_file in data_files:
            d = np.loadtxt(data_file)

            # Apply the optional per-column functions to the data.
            if data_function1 is not None:
                d[:, col1] = [data_function1(value) for value in d[:, col1]]
            if data_function2 is not None:
                d[:, col2] = [data_function2(value) for value in d[:, col2]]

            c1 = d[:, col1]
            c2 = d[:, col2]

            # Plot the data points if the user asks for it
            if config.get('plot_points', False):
                ax.scatter(c1, c2, s=5, marker='x')

            # Fraction of the points in every bin, indexed [x, y].  The counts
            # are normalised by hand so no version-specific normalisation
            # keyword of histogram2d is needed.
            counts, xedges, yedges = np.histogram2d(c1, c2, bins=config.get('bins', 100))
            h = counts / counts.sum()
            bin_area = (yedges[1] - yedges[0]) * (xedges[1] - xedges[0])
            xcenters = xedges[:-1] + 0.5 * (xedges[1:] - xedges[:-1])
            ycenters = yedges[:-1] + 0.5 * (yedges[1:] - yedges[:-1])

            # Evaluate the kernel density estimate at the bin centres.
            # meshgrid's default 'xy' indexing makes the grid [y, x], which is
            # the layout contour() expects, so it is NOT transposed here.
            X, Y = np.meshgrid(xcenters, ycenters)
            positions = np.vstack([X.ravel(), Y.ravel()])
            kernel = gaussian_kde(np.vstack([c1, c2]))
            kde_mass = np.reshape(kernel(positions) * bin_area, X.shape)

            # Cumulative probability, densest cell first.
            Z, peak = cumulative_probability(kde_mass)

            # plot the best fit point (Z is indexed [y, x])
            iy, ix = np.unravel_index(peak, Z.shape)
            color = Color.get_current_color()
            ax.scatter(xcenters[ix], ycenters[iy], c=color, s=30)
            print('Best fit point is ({0},{1})'.format(xcenters[ix], ycenters[iy]), file=sys.stderr)

            # plot the contour
            con = ax.contour(xcenters, ycenters, Z, levels=levels, colors=color)
            ax.clabel(con, inline=1, fontsize=10)
            # NOTE(review): ContourSet.collections is reported as deprecated in
            # recent Matplotlib releases -- confirm against the installed version.
            lines.append(con.collections[0])

            # Track the widest axis limits seen so far.
            x_low, x_high = ax.get_xlim()
            max_lim['x_low'] = min(max_lim['x_low'], x_low)
            max_lim['x_high'] = max(max_lim['x_high'], x_high)
            y_low, y_high = ax.get_ylim()
            max_lim['y_low'] = min(max_lim['y_low'], y_low)
            max_lim['y_high'] = max(max_lim['y_high'], y_high)

            # Count the fraction of points inside each level.  h is indexed
            # [x, y], so the [y, x] grid is transposed for the comparison.
            num_levels = mass_within_levels(h, Z.T, levels)
            for level, contained in zip(levels, num_levels):
                print('Level ' + str(level) + ' contains ' + str(contained) + ' of points', file=sys.stderr)

            print('---', file=sys.stderr)

            Color.get_next_color()

        Color.reset_color()
        print('------------', file=sys.stderr)

        # Plot the reference points, transformed like the data.  An empty
        # list is skipped just like a missing entry.
        ref_points = transform_ref_points(config.get('ref_points', None),
                                          data_function1, data_function2)
        if ref_points:
            ref_x, ref_y = zip(*ref_points)
            ax.scatter(ref_x, ref_y, c='r', marker='*', s=30)

        # plot the legend
        legend = config.get('legend', None)
        if legend is not None:
            ax.legend(lines, legend, loc='lower left')

    if all_config.get('same_range', False):
        print('Using the same axis range for all plots', file=sys.stderr)
        for ax in axes:
            ax.set_xlim(max_lim['x_low'], max_lim['x_high'])
            ax.set_ylim(max_lim['y_low'], max_lim['y_high'])

    fig.suptitle(all_config.get('main_title', ''), fontsize=20)

    if len(sys.argv) > 1:
        plt.savefig(sys.argv[1])
    else:
        plt.show()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment