Skip to content

Instantly share code, notes, and snippets.

@egpbos
Created February 23, 2016 09:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save egpbos/2d5733b4a921a7c6821b to your computer and use it in GitHub Desktop.
Save egpbos/2d5733b4a921a7c6821b to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import networkx as nx
def subplot_axes(rows, cols, textwidth=None, subplot_aspect=5./4, dpi=200):
    """
    Create a figure holding a `rows` x `cols` grid of subplots.

    `textwidth` is the figure width in inches (10 when None, a sensible
    matplotlib/ipython default). The figure height is chosen so that each
    subplot has the aspect ratio `subplot_aspect`, i.e. the horizontal
    divided by the vertical size of a subplot.

    Returns (fig, axes) where `axes` is a flat list of the rows*cols
    subplots in row-major order.
    """
    width = 10. if textwidth is None else textwidth
    height = float(width) / cols * rows / subplot_aspect
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    axes = [fig.add_subplot(rows, cols, number)
            for number in range(1, rows * cols + 1)]
    return fig, axes
def slab_plot(slab, axes=None, title="", vmin=None,
              vmax=None, rel_smoothing_scale=None, boxsize=None,
              colorbar=True, fix_colorbar=True, v_ratio=False, cmap=None,
              norm=None, log_cbar=False, vmid=None,
              xlabel="Mpc/h", ylabel="Mpc/h", norm_clip=True,
              cut_x=None, cut_y=None, alpha=1., return_objects=False,
              cblabel=None, **kwargs):
    """
    Show the 2D array `slab` with `imshow`, optionally cutting out a
    sub-region and adding a colorbar.

    slab: 2D array; axis 0 is plotted along y, axis 1 along x.
    axes: matplotlib Axes to draw in; a new figure is created if None.
    vmin, vmax, v_ratio: passed to `vminmax_grid_slice` (defined outside
        this file) to determine the color range.
    boxsize: physical size (Mpc/h) of each axis of the full slab. When
        given, the image gets a physical extent; when falsy the axis is
        switched off.
    cut_x, cut_y: (min, max) tuple in Mpc/h selecting a sub-region; cells
        are included by rounding outwards. Requires `boxsize`.
    norm, cmap, log_cbar, vmid, norm_clip: used to build the color norm via
        `determine_norm` (defined outside this file). With `log_cbar` the
        plotted data is log10(2 + slab).
    colorbar, fix_colorbar, cblabel: add a colorbar; with `fix_colorbar` it
        gets its own axes to the right of `axes`.
    alpha, **kwargs: forwarded to `imshow`.
    rel_smoothing_scale: not used.
    return_objects: also return `axes` (and the colorbar, if drawn).

    Returns (vmin, vmax), or (vmin, vmax, axes[, cbar]) with return_objects.
    Raises ValueError if `cut_x` or `cut_y` is given without `boxsize`.
    """
    # The cuts are physical coordinates, so they can only be converted to
    # cell indices when the box size is known.
    if (cut_x is not None or cut_y is not None) and boxsize is None:
        raise ValueError("cut_x and cut_y are given in Mpc/h and require "
                         "boxsize to be set")
    if axes is None:
        fig = plt.figure()
        axes = fig.add_subplot(111)
    vmin, vmax = vminmax_grid_slice(slab, vmin, vmax, v_ratio)

    gridsize_x = slab.shape[1]
    gridsize_y = slab.shape[0]

    def cells_in_cut(cut, gridsize):
        # Cell index range covering the physical interval `cut` (rounded
        # outwards), or the whole axis when no cut is requested.
        if cut is None:
            return slice(0, gridsize)
        return slice(int(np.floor(cut[0]/boxsize * gridsize)),
                     int(np.ceil(cut[1]/boxsize * gridsize)))

    def physical_extent(cut, cell_slice, gridsize):
        # Physical (Mpc/h) range spanned by the plotted cells on one axis.
        if cut is None:
            return (0, boxsize)
        return (cell_slice.start * boxsize / gridsize,
                cell_slice.stop * boxsize / gridsize)

    x_slice = cells_in_cut(cut_x, gridsize_x)
    y_slice = cells_in_cut(cut_y, gridsize_y)

    extent = None
    if boxsize is not None:
        x_extent = physical_extent(cut_x, x_slice, gridsize_x)
        y_extent = physical_extent(cut_y, y_slice, gridsize_y)
        extent = [x_extent[0], x_extent[1], y_extent[0], y_extent[1]]
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)

    norm = determine_norm(norm, vmin, vmax, cmap, log_cbar, vmid=vmid,
                          clip=norm_clip)
    data = slab[y_slice, x_slice]
    if log_cbar:
        data = np.log10(2 + data)
    cax = axes.imshow(data, extent=extent, cmap=cmap, norm=norm, alpha=alpha,
                      **kwargs)
    axes.set_title(title)
    if not boxsize:
        axes.set_axis_off()

    if colorbar:
        if fix_colorbar:
            # Imported here because the file never imported it at top level.
            from mpl_toolkits.axes_grid1 import make_axes_locatable
            divider = make_axes_locatable(axes)  # Don't use plt.gca() instead of axes!
            cax_cb = divider.append_axes("right", "5%", pad="3%")
            cbar = axes.figure.colorbar(cax, cax=cax_cb)
        else:
            cbar = axes.figure.colorbar(cax)
        # NOTE(review): Colorbar.set_clim is deprecated/removed in newer
        # matplotlib releases; cax.set_clim(vmin, vmax) may be needed there.
        cbar.set_clim(vmin, vmax)
        ticks, ticklabels = determine_ticks(norm, log_cbar)
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(ticklabels)
        if cblabel is not None:
            cbar.set_label(cblabel)

    if not return_objects:
        return vmin, vmax
    if colorbar:
        return vmin, vmax, axes, cbar
    return vmin, vmax, axes
def minimized_id_slice(object_id_grid, slice_index=0):
    """
    Remap the object IDs of one slice of `object_id_grid` to the compact
    range 1..num, where num is the number of distinct non-negative IDs in
    the slice. The new IDs are assigned in a random order
    (`np.random.shuffle`).

    -1 cells stay -1; this is used to signify that these cells have a
    different object type.

    Parameters
    ----------
    object_id_grid : integer array; `object_id_grid[slice_index]` is the
        slice to remap. It is expected to contain non-negative object IDs
        and/or -1.
    slice_index : index of the slice along the first axis.

    Returns
    -------
    Integer array with the shape of the slice, holding the remapped IDs.
    """
    object_slice = object_id_grid[slice_index]
    unique_objects = np.unique(object_slice)
    # Only the non-negative IDs get a new number; -1 is handled below. The
    # slice may or may not contain -1 cells (or a zero ID), so count the
    # real IDs explicitly instead of assuming one extra unique value.
    object_ids = unique_objects[unique_objects >= 0]
    # Reverse lookup table indexed by old ID (0..max). Entries for IDs that
    # do not occur in the slice are never read.
    object_id_range = np.arange(object_slice.max() + 1)
    new_ids = np.arange(1, len(object_ids) + 1)
    np.random.shuffle(new_ids)
    # object_ids is sorted (np.unique), so the i-th occurring ID gets the
    # i-th shuffled new ID.
    object_id_range[object_ids] = new_ids
    # Append a -1 entry: indexing with -1 picks the last element, so
    # cells of different type map -1 -> -1.
    object_id_range = np.append(object_id_range, (-1,))
    return object_id_range[object_slice]
def object_slice_graph(object_slice, neighbour_cells_diff=None):
    """
    Make a graph out of a 2D slice with object IDs.

    Every ID >= 0 is a node. Two different IDs are connected by an edge when
    a cell of one ID neighbours a cell of the other ID. Neighbours wrap
    around the slice edges (periodic boundaries). Cells with a negative ID
    (-1: different object type) are ignored, and no self-loops are added.

    neighbour_cells_diff: callable taking the number of dimensions (2) and
        returning an iterable of 2D index offsets to the neighbouring cells.
        When None, the 4 face neighbours (+-1 along each axis) are used.
        (The previous default, `neighbour_cells6_diff`, is not defined in
        this file, so the module could not even be imported.)

    Returns a networkx Graph.
    """
    # Wrap per axis, so non-square slices are handled too.
    shape = np.array(object_slice.shape)
    graph = nx.Graph()
    if neighbour_cells_diff is None:
        neigh2D = [np.array(offset)
                   for offset in ((1, 0), (-1, 0), (0, 1), (0, -1))]
    else:
        neigh2D = neighbour_cells_diff(2)
    for ix, row in enumerate(object_slice):
        for jx, cell in enumerate(row):
            if cell < 0:
                continue
            graph.add_node(cell)
            here = np.array((ix, jx))
            for d_nb in neigh2D:
                nb_obj = object_slice[tuple((here + d_nb) % shape)]
                # Only link to a *different* object of the same type; a
                # same-ID neighbour would create a self-loop and inflate
                # the degrees used by the coloring.
                if nb_obj >= 0 and nb_obj != cell:
                    graph.add_edge(cell, nb_obj)
    return graph
def minimum_graph_coloring(graph):
    """
    Color the graph nodes so that no two neighbouring nodes share a color,
    using as few colors as the smallest-last greedy heuristic achieves.

    Nodes are first ordered "smallest degree last". Then each node, in that
    order, gets the smallest color (1, 2, ...) not yet used by any of its
    neighbours. For a planar graph (e.g. a land map) this gives at most 6
    colors; more complicated graphs may need more.
    http://i.stanford.edu/pub/cstr/reports/cs/tr/80/830/CS-TR-80-830.pdf

    `graph` must behave like a networkx graph: `graph.nodes()` yields the
    nodes and `graph[node]` yields the neighbours of `node`. Self-loops are
    ignored.

    Node IDs must be non-negative integers; they are used directly as
    indices into the result. A compact range such as 0..N-1 or 1..N keeps
    the result small.

    Returns an int32 array `node_color` where `node_color[node]` is the
    color (>= 1) of `node`. Indices that are not nodes hold 0, and the last
    element is -1, so indexing with -1 (cells of different type) yields -1.
    """
    nodes = list(graph.nodes())
    N_nodes = len(nodes)
    # Neighbour sets without self-loops.
    neighbours = {node: set(graph[node]) - {node} for node in nodes}

    # 1. Build degree lists: one bucket of nodes per current degree.
    degree = {node: len(neighbours[node]) for node in nodes}
    degree_lists = collections.defaultdict(set)
    for node in nodes:
        degree_lists[degree[node]].add(node)

    # 2. Label nodes smallest degree last: repeatedly remove a node of
    # minimum remaining degree and place it at the back of the order.
    nodes_ordered = [None] * N_nodes
    removed = set()
    for position in range(N_nodes - 1, -1, -1):
        j = min(d for d, bucket in degree_lists.items() if bucket)
        node = degree_lists[j].pop()
        removed.add(node)
        nodes_ordered[position] = node
        for nb in neighbours[node]:
            if nb in removed:
                continue
            d = degree[nb]
            degree_lists[d].remove(nb)
            degree[nb] = d - 1
            degree_lists[d - 1].add(nb)

    # 3. Color nodes greedily with the first free color. Size the array
    # for the largest ID so that it can neither overflow nor overwrite the
    # trailing -1 sentinel.
    node_color = np.zeros(max([N_nodes] + nodes) + 2, dtype='int32')
    for node in nodes_ordered:
        used = {int(node_color[nb]) for nb in neighbours[node]}
        color = 1
        while color in used:
            color += 1
        node_color[node] = color
    node_color[-1] = -1
    return node_color
def minimum_slice_coloring(object_id_grid, slice_index=0, coloring_tries=3):
    """
    Color one slice of `object_id_grid` so that neighbouring objects differ.

    The slice is recolored `coloring_tries` times, each time after a fresh
    random renumbering of the object IDs (`minimized_id_slice`), because the
    number of colors depends on that numbering. The attempt using the
    fewest colors is kept; on a tie the earliest attempt wins.

    Returns (slice_shuffled, grid_colors): the renumbered slice and the
    per-ID color array of the chosen attempt.
    """
    attempts = []
    for _ in range(coloring_tries):
        renumbered = minimized_id_slice(object_id_grid,
                                        slice_index=slice_index)
        colors = minimum_graph_coloring(object_slice_graph(renumbered))
        attempts.append((renumbered, colors))
    best_slice, best_colors = min(attempts,
                                  key=lambda attempt: attempt[1].max())
    return best_slice, best_colors
def plot_object_id_slab(slice_shuffled, grid_colors, boxsize, slice_index=0,
                        ax=None, textwidth=17, coloring_tries=3, bg_color='k',
                        cut_x=None, cut_y=None, cmap_large='husl', alpha=1,
                        desat=0.6, xlabel="Mpc/h", ylabel="Mpc/h",
                        force_cmap_large=False):
    """
    Plot a colored object-ID slice, as produced by `minimum_slice_coloring`.

    slice_shuffled: 2D array of object IDs (-1 for cells of different type).
    grid_colors: array mapping object ID -> color number (>= 1), with -1 as
        its last element, so that -1 cells get color -1.
    boxsize, cut_x, cut_y, alpha, xlabel, ylabel: passed to `slab_plot`.
    ax: Axes to draw in; a new figure of width `textwidth` is made if None.
    bg_color: color of the -1 cells (the colormap's "under" color).
    desat: desaturation of the default "Set1" palette.
    cmap_large, force_cmap_large: seaborn palette used instead of "Set1"
        when more than 9 colors are needed, or when forced.
    slice_index, coloring_tries: not used here.

    Returns the Axes that was drawn in.
    """
    N_colors = grid_colors.max()
    if N_colors > 9 or force_cmap_large:
        # print(...) with the whole formatted string, so this works on both
        # Python 2 and 3 (`print (...).format` calls .format on None in 3).
        print(("Number of colors is {0}, Set1 palette only has 9 colors. " +
               "Changed palette to {1}.").format(N_colors, cmap_large))
        palette = sns.color_palette(cmap_large, N_colors)
    else:
        palette = sns.color_palette("Set1", n_colors=N_colors, desat=desat)
    pal = matplotlib.colors.ListedColormap(palette)
    # -1 cells fall below vmin=0.5 and are drawn in the background color.
    pal.set_under(color=bg_color)
    if ax is None:
        # subplot_axes takes `cols` and returns a list of Axes.
        fig, axes_list = subplot_axes(rows=1, cols=1, textwidth=textwidth)
        ax = axes_list[0]
    slab_plot(grid_colors[slice_shuffled], boxsize=boxsize, axes=ax,
              cmap=pal, vmin=0.5, vmax=N_colors + 0.5, colorbar=False,
              cut_x=cut_x, cut_y=cut_y, alpha=alpha, xlabel=xlabel,
              ylabel=ylabel)
    return ax
def plot_object_id_grid(object_id_grid, boxsize, slice_index=0, ax=None,
                        textwidth=17, coloring_tries=3, bg_color='k'):
    """
    Color one slice of `object_id_grid` with as few colors as possible, then
    plot it.

    This is a convenience wrapper around `minimum_slice_coloring` followed
    by `plot_object_id_slab`. Returns the Axes that was drawn in.
    """
    coloring = minimum_slice_coloring(object_id_grid,
                                      slice_index=slice_index,
                                      coloring_tries=coloring_tries)
    renumbered_slice, id_colors = coloring
    return plot_object_id_slab(renumbered_slice, id_colors, boxsize,
                               slice_index=slice_index, ax=ax,
                               textwidth=textwidth,
                               coloring_tries=coloring_tries,
                               bg_color=bg_color)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment