Created
February 23, 2016 09:49
-
-
Save egpbos/2d5733b4a921a7c6821b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
import numpy as np | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import collections | |
import networkx as nx | |
def subplot_axes(rows, cols, textwidth=None, subplot_aspect=5./4, dpi=200):
    """
    Create a figure holding a `rows` x `cols` grid of subplots.

    The figure width is `textwidth` (10 when None); its height is chosen so
    that each subplot has the requested aspect ratio.

    `subplot_aspect` is the horizontal divided by the vertical size of a
    subplot.

    Returns (fig, axes), where `axes` is a list of the rows * cols subplot
    axes in matplotlib's subplot numbering order (1 .. rows * cols).
    """
    width = 10. if textwidth is None else textwidth  # sensible matplotlib ipython default
    height = float(width) / cols * rows / subplot_aspect
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    axes = [fig.add_subplot(rows, cols, number)
            for number in range(1, rows * cols + 1)]
    return fig, axes
def _axis_slice_extent(cut, gridsize, boxsize):
    """
    Index slice and physical extent along one axis of a slab.

    `cut` is a (low, high) range in the units of `boxsize` (Mpc/h), or None
    for the whole axis. Returns (index_slice, extent), where extent is a
    (low, high) tuple in the units of `boxsize`, or None if `boxsize` is None.

    Raises TypeError when a cut is requested without a `boxsize`.
    """
    if cut is None:
        extent = None if boxsize is None else (0, boxsize)
        return slice(0, gridsize), extent
    if boxsize is None:
        raise TypeError("cut_x/cut_y are given in Mpc/h and need `boxsize` "
                        "to be converted to grid cells.")
    # float() guards against integer division under Python 2. floor/ceil
    # widen the range to whole cells so the requested range is fully shown.
    index_slice = slice(int(np.floor(float(cut[0]) / boxsize * gridsize)),
                        int(np.ceil(float(cut[1]) / boxsize * gridsize)))
    extent = (index_slice.start * float(boxsize) / gridsize,
              index_slice.stop * float(boxsize) / gridsize)
    return index_slice, extent


def slab_plot(slab, axes=None, title="", vmin=None,
              vmax=None, rel_smoothing_scale=None, boxsize=None,
              colorbar=True, fix_colorbar=True, v_ratio=False, cmap=None,
              norm=None, log_cbar=False, vmid=None,
              xlabel="Mpc/h", ylabel="Mpc/h", norm_clip=True,
              cut_x=None, cut_y=None, alpha=1., return_objects=False,
              cblabel=None, **kwargs):
    """
    Show a 2D slab with `imshow`, optionally with a colorbar.

    Parameters
    ----------
    slab : 2D array indexed as slab[y, x].
    axes : matplotlib Axes to draw on; a new figure is created when None.
    title, xlabel, ylabel : axis annotations.
    vmin, vmax, v_ratio : passed to `vminmax_grid_slice` (defined outside
        this block) to determine the color limits.
    rel_smoothing_scale : unused.
    boxsize : physical size of the slab along each axis (same units as
        cut_x/cut_y). Sets the imshow extent; when falsy the axes are hidden.
    cmap, norm, vmid, norm_clip : passed to `determine_norm`.
    log_cbar : if True, plot log10(2 + slab) instead of the slab itself.
    colorbar : draw a colorbar.
    fix_colorbar : attach the colorbar to an axes appended to the right of
        `axes`, instead of letting matplotlib steal space from the figure.
    cut_x, cut_y : tuple in Mpc/h, (low, high) range to show; needs boxsize.
    alpha, **kwargs : passed to imshow.
    return_objects : also return the axes (and colorbar).
    cblabel : colorbar label.

    Returns
    -------
    (vmin, vmax); with return_objects, (vmin, vmax, axes, cbar) if a
    colorbar was drawn, else (vmin, vmax, axes).

    Raises
    ------
    TypeError : if cut_x or cut_y is given without boxsize.
    """
    if axes is None:
        axes = plt.figure().add_subplot(111)
    vmin, vmax = vminmax_grid_slice(slab, vmin, vmax, v_ratio)

    # slab is indexed [y, x]: axis 1 is x, axis 0 is y.
    x_slice, x_extent = _axis_slice_extent(cut_x, slab.shape[1], boxsize)
    y_slice, y_extent = _axis_slice_extent(cut_y, slab.shape[0], boxsize)
    extent = None
    if boxsize is not None:
        extent = [x_extent[0], x_extent[1], y_extent[0], y_extent[1]]

    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    norm = determine_norm(norm, vmin, vmax, cmap, log_cbar, vmid=vmid,
                          clip=norm_clip)
    image_data = slab[y_slice, x_slice]
    if log_cbar:
        # NOTE(review): the +2 offset presumably keeps the log argument
        # positive for the expected data range — confirm.
        image_data = np.log10(2 + image_data)
    image = axes.imshow(image_data, extent=extent, cmap=cmap, norm=norm,
                        alpha=alpha, **kwargs)
    axes.set_title(title)
    if not boxsize:
        axes.set_axis_off()

    cbar = None
    if colorbar:
        if fix_colorbar:
            # Local import: make_axes_locatable is not imported at module level.
            from mpl_toolkits.axes_grid1 import make_axes_locatable
            # Divide `axes` itself, not plt.gca(): the caller may have passed
            # an axes that is not the current one.
            divider = make_axes_locatable(axes)
            cbar_axes = divider.append_axes("right", "5%", pad="3%")
            cbar = axes.figure.colorbar(image, cax=cbar_axes)
        else:
            cbar = axes.figure.colorbar(image)
        cbar.set_clim(vmin, vmax)
        ticks, ticklabels = determine_ticks(norm, log_cbar)
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(ticklabels)
        if cblabel is not None:
            cbar.set_label(cblabel)

    if not return_objects:
        return vmin, vmax
    if colorbar:
        return vmin, vmax, axes, cbar
    return vmin, vmax, axes
def minimized_id_slice(object_id_grid, slice_index=0):
    """
    Relabel the object IDs of one slice to 1..num in random order.

    `num` is the number of distinct non-negative IDs present in the slice.
    The new IDs are a random permutation of 1..num (drawn with
    np.random.shuffle), so no object is mapped to 0. Cells equal to -1 stay
    -1; these signify cells of a different object type. Other negative
    values are not supported: they would index the lookup table from its end.

    Parameters
    ----------
    object_id_grid : integer array; `object_id_grid[slice_index]` is the
        slice to relabel.
    slice_index : index along the first axis of `object_id_grid`.

    Returns
    -------
    Integer array with the shape of the slice, holding the new IDs.
    """
    object_slice = object_id_grid[slice_index]
    unique_objects = np.unique(object_slice)
    # Count only real objects. -1 may or may not be present; the old
    # `len(unique) - 1` count assumed it always was.
    present_ids = unique_objects[unique_objects >= 0]
    new_ids = np.arange(1, len(present_ids) + 1)
    np.random.shuffle(new_ids)
    # Lookup table indexed by old ID (0..max). Entries for IDs absent from
    # the slice are never read.
    lookup = np.arange(object_slice.max() + 1)
    lookup[present_ids] = new_ids
    # Extra last entry so that old ID -1 (indexing from the end) maps to -1.
    lookup = np.append(lookup, (-1,))
    return lookup[object_slice]
def object_slice_graph(object_slice, neighbour_cells_diff=None):
    """
    Build the adjacency graph of the objects in a 2D slice of object IDs.

    Every ID >= 0 in the slice becomes a node. Two different IDs are joined
    by an edge when a cell of one neighbours a cell of the other. Neighbours
    are given by the offsets from `neighbour_cells_diff` and wrap around the
    slice edges (periodic boundaries, per axis, so the slice need not be
    square). Negative IDs are ignored, and no node is connected to itself.

    Parameters
    ----------
    object_slice : 2D integer array of object IDs.
    neighbour_cells_diff : callable, called as neighbour_cells_diff(2). It
        must return an iterable of 2-component offsets to neighbouring
        cells. When None, `neighbour_cells6_diff` is used. That function is
        presumably defined elsewhere in the project; TODO confirm.

    Returns
    -------
    networkx.Graph with the object IDs as nodes.
    """
    if neighbour_cells_diff is None:
        # Looked up at call time rather than as a def-time default, so the
        # module can be imported even if neighbour_cells6_diff is defined
        # later or elsewhere.
        neighbour_cells_diff = neighbour_cells6_diff
    shape = np.array(object_slice.shape)
    offsets = neighbour_cells_diff(2)
    graph = nx.Graph()
    for index, cell in np.ndenumerate(object_slice):
        if cell < 0:
            continue
        graph.add_node(cell)
        for offset in offsets:
            # Per-axis modulo: periodic wrap, also correct for non-square slices.
            nb_obj = object_slice[tuple((np.array(index) + offset) % shape)]
            # Skip self-loops: only cells of a *different* ID form edges.
            if nb_obj >= 0 and nb_obj != cell:
                graph.add_edge(cell, nb_obj)
    return graph
def minimum_graph_coloring(graph):
    """
    Color the graph nodes so that no two neighbouring nodes share a color,
    using as few colors as this heuristic allows.

    Uses the smallest-last ordering of Matula & Beck: repeatedly remove a
    node of minimum remaining degree, then color the nodes in reverse
    removal order. Each node gets the smallest color (>= 1) not used by an
    already colored neighbour. For a planar graph (e.g. a land map) this
    gives at most 6 colors; for more complicated graphs it may be more.
    http://i.stanford.edu/pub/cstr/reports/cs/tr/80/830/CS-TR-80-830.pdf

    Self-loops are ignored: a node never constrains its own color.

    Note: this function assumes that the node IDs are integers forming a
    sequential range, 0..N-1 or 1..N with N the number of nodes (e.g.
    0,1,2,3; not 1,4,5,7). The IDs are used directly as indices into the
    returned array.

    Parameters
    ----------
    graph : networkx-like graph supporting `graph.nodes()` and `graph[node]`
        (iterable over the neighbours of `node`).

    Returns
    -------
    int32 array of length N + 2. Entry `id` holds the color (>= 1) of node
    `id`; entries without a node stay 0. The last entry is -1, so indexing
    it with an ID of -1 yields -1.
    """
    nodes = list(graph.nodes())
    n_nodes = len(nodes)
    neighbours = {node: set(graph[node]) - {node} for node in nodes}
    degree = {node: len(nbs) for node, nbs in neighbours.items()}
    max_degree = max(degree.values()) if degree else 0  # empty graph safe

    # 1. Build degree lists (buckets of nodes keyed by current degree)
    degree_lists = collections.defaultdict(set)
    for node, node_degree in degree.items():
        degree_lists[node_degree].add(node)

    # 2. Label vertices (nodes) smallest degree last
    nodes_ordered = [None] * n_nodes
    remaining = set(nodes)
    for designated in range(1, n_nodes + 1):
        lowest = next(d for d in range(max_degree + 1) if degree_lists[d])
        node = degree_lists[lowest].pop()
        remaining.discard(node)
        nodes_ordered[n_nodes - designated] = node
        # Removing `node` lowers the degree of each remaining neighbour.
        for nb in neighbours[node]:
            if nb in remaining:
                degree_lists[degree[nb]].remove(nb)
                degree[nb] -= 1
                degree_lists[degree[nb]].add(nb)

    # 3. Color vertices greedily with the smallest free color
    node_color = np.zeros(n_nodes + 2, dtype='int32')
    for node in nodes_ordered:
        taken = {int(node_color[nb]) for nb in neighbours[node]}
        color = 1
        while color in taken:
            color += 1
        node_color[node] = color
    node_color[-1] = -1
    return node_color
def minimum_slice_coloring(object_id_grid, slice_index=0, coloring_tries=3):
    """
    Color one slice of `object_id_grid`, keeping the best of several tries.

    Each try randomly relabels the slice (minimized_id_slice), builds the
    object adjacency graph, and colors it. The try whose coloring has the
    smallest maximum color wins; on ties the earliest try is kept.

    Returns (slice_shuffled, grid_colors) of the winning try.
    """
    def run_attempt():
        shuffled = minimized_id_slice(object_id_grid, slice_index=slice_index)
        return shuffled, minimum_graph_coloring(object_slice_graph(shuffled))

    attempts = [run_attempt() for _ in range(coloring_tries)]
    return min(attempts, key=lambda attempt: attempt[1].max())
def plot_object_id_slab(slice_shuffled, grid_colors, boxsize, slice_index=0,
                        ax=None, textwidth=17, coloring_tries=3, bg_color='k',
                        cut_x=None, cut_y=None, cmap_large='husl', alpha=1,
                        desat=0.6, xlabel="Mpc/h", ylabel="Mpc/h",
                        force_cmap_large=False):
    """
    Plot a slice of object IDs, coloring each object by its assigned color.

    Parameters
    ----------
    slice_shuffled : 2D integer array of object IDs, -1 for cells of a
        different object type (e.g. from minimum_slice_coloring).
    grid_colors : lookup array, object ID -> color number (>= 1), whose last
        entry is -1 so that -1 cells map to -1.
    boxsize, cut_x, cut_y, alpha, xlabel, ylabel : passed to slab_plot.
    slice_index, coloring_tries : unused; kept for call compatibility.
    ax : Axes to draw on; a new single-subplot figure of width `textwidth`
        is created when None.
    bg_color : color of the -1 cells (values below the colormap range).
    cmap_large : seaborn palette used when more than 9 colors are needed
        or when `force_cmap_large` is set.
    desat : desaturation of the default Set1 palette.

    Returns
    -------
    The Axes that was drawn on.
    """
    n_colors = grid_colors.max()
    if n_colors > 9 or force_cmap_large:
        # print(...) with a single argument works under Python 2 and 3; the
        # old `print (...).format()` form fails on Python 3.
        print(("Number of colors is {0}, Set1 palette only has 9 colors. " +
               "Changed palette to {1}.").format(n_colors, cmap_large))
        palette = sns.color_palette(cmap_large, n_colors)
    else:
        palette = sns.color_palette("Set1", n_colors=n_colors, desat=desat)
    pal = matplotlib.colors.ListedColormap(palette)
    # -1 cells fall below vmin=0.5 and are drawn in the "under" color.
    pal.set_under(color=bg_color)
    if ax is None:
        # subplot_axes takes `cols` (not `columns`) and returns a list of axes.
        _, axes_list = subplot_axes(rows=1, cols=1, textwidth=textwidth)
        ax = axes_list[0]
    # Colors are integers 1..n_colors; the +-0.5 bounds center them on the
    # colormap bins.
    slab_plot(grid_colors[slice_shuffled], boxsize=boxsize, axes=ax,
              cmap=pal, vmin=0.5, vmax=n_colors + 0.5, colorbar=False,
              cut_x=cut_x, cut_y=cut_y, alpha=alpha, xlabel=xlabel,
              ylabel=ylabel)
    return ax
def plot_object_id_grid(object_id_grid, boxsize, slice_index=0, ax=None,
                        textwidth=17, coloring_tries=3, bg_color='k',
                        **plot_kwargs):
    """
    Color one slice of an object-ID grid with few colors and plot it.

    Runs minimum_slice_coloring (best of `coloring_tries` random tries) on
    `object_id_grid[slice_index]` and draws the result with
    plot_object_id_slab.

    Parameters
    ----------
    object_id_grid : integer array of object IDs (-1 for other cell types).
    boxsize : passed to plot_object_id_slab / slab_plot.
    slice_index : index of the slice along the first axis.
    ax, textwidth, bg_color : passed to plot_object_id_slab.
    coloring_tries : number of coloring attempts.
    **plot_kwargs : further keyword arguments forwarded to
        plot_object_id_slab (e.g. cut_x, cut_y, alpha, desat, cmap_large).

    Returns
    -------
    The Axes that was drawn on.
    """
    slice_shuffled, grid_colors = minimum_slice_coloring(
        object_id_grid, slice_index=slice_index,
        coloring_tries=coloring_tries)
    return plot_object_id_slab(slice_shuffled, grid_colors, boxsize,
                               slice_index=slice_index,
                               ax=ax, textwidth=textwidth,
                               coloring_tries=coloring_tries,
                               bg_color=bg_color, **plot_kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment