Created
November 11, 2010 17:26
-
-
Save astrofrog/672850 to your computer and use it in GitHub Desktop.
fun with astronomical dendrograms (test code)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os | |
import getopt | |
import pyfits | |
import numpy as np | |
import matplotlib | |
matplotlib.use('Agg') | |
import matplotlib.pyplot as plt | |
from matplotlib.collections import LineCollection | |
from meshgrid import meshgrid_nd | |
# Last structure id handed out. Ids start at 1 so that 0 can mean
# "no structure" in the index map.
IDX_COUNTER = 0


def next_idx():
    """Advance the module-level id counter and return the new id."""
    global IDX_COUNTER
    IDX_COUNTER = IDX_COUNTER + 1
    new_id = IDX_COUNTER
    return new_id
# Parse the command line: options first, input FITS file always last.
if len(sys.argv) < 2:
    sys.stderr.write("usage: %s [-m minimum] [-o output] [-n min_n_pix] "
                     "[-d min_delta] filename\n" % sys.argv[0])
    sys.exit(2)

try:
    opts, args = getopt.getopt(sys.argv[1:-1], "m:o:n:d:",
                               ["minimum=", "output=", "min_n_pix=", "min_delta="])
except getopt.GetoptError as err:
    sys.stderr.write(str(err) + "\n")
    sys.exit(2)

filename = sys.argv[-1]

# Default output name: the input name with a trailing .gz and/or .fits
# removed (only as suffixes, not anywhere in the path).
_stem = filename
for _suffix in ('.gz', '.fits'):
    if _stem.endswith(_suffix):
        _stem = _stem[:-len(_suffix)]

PAR = {}
PAR['minimum'] = -np.inf  # pixels with flux <= this are ignored
PAR['output'] = _stem + "_dendrogram.fits"
PAR['npix'] = 0           # leaves with fewer pixels are not kept separate
PAR['delta'] = 0.         # leaves peaking less than this above the merge level are not kept separate
PAR['format'] = 'fits'    # output format; 'fits' is the only one handled

for o, a in opts:
    if o in ("-m", "--minimum"):
        PAR['minimum'] = float(a)
    elif o in ("-o", "--output"):
        PAR['output'] = a
    elif o in ("-n", "--min_n_pix"):
        PAR['npix'] = int(a)
    elif o in ("-d", "--min_delta"):
        PAR['delta'] = float(a)
    else:
        # getopt only returns the options declared above, so this should be
        # unreachable; raise explicitly so it is not stripped under -O.
        raise AssertionError("unhandled option")
class Leaf(object):
    """A set of connected pixels grown from one local maximum.

    Pixel coordinates and fluxes are kept in parallel 1-D arrays ``x``,
    ``y``, ``z`` and ``f``, alongside their running minima and maxima.
    ``parent`` is the Branch this structure was joined into, or None while
    it has not been joined.
    """

    def __init__(self, x, y, z, f):
        self.x = np.array([x], dtype=int)
        self.y = np.array([y], dtype=int)
        self.z = np.array([z], dtype=int)
        self.f = np.array([f], dtype=float)
        self.xmin, self.xmax = x, x
        self.ymin, self.ymax = y, y
        self.zmin, self.zmax = z, z
        self.fmin, self.fmax = f, f
        self.parent = None

    def __getattr__(self, attribute):
        # Only reached for attributes not found by normal lookup. npix is
        # derived from the pixel arrays so it can never go stale.
        if attribute == 'npix':
            return len(self.x)
        # Must be AttributeError (not a bare Exception) so that hasattr()
        # and getattr(obj, name, default) behave correctly, e.g.
        # hasattr(leaf, 'id') before set_id() has been called.
        raise AttributeError("Attribute not found: %s" % attribute)

    def add_point(self, x, y, z, f):
        """Append one pixel (x, y, z) with flux f and update the extrema."""
        self.x = np.hstack([self.x, x])
        self.y = np.hstack([self.y, y])
        self.z = np.hstack([self.z, z])
        self.f = np.hstack([self.f, f])
        self.xmin, self.xmax = min(x, self.xmin), max(x, self.xmax)
        self.ymin, self.ymax = min(y, self.ymin), max(y, self.ymax)
        self.zmin, self.zmax = min(z, self.zmin), max(z, self.zmax)
        self.fmin, self.fmax = min(f, self.fmin), max(f, self.fmax)

    def merge(self, leaf):
        """Append all pixels of ``leaf`` to this structure and update the extrema."""
        self.x = np.hstack([self.x, leaf.x])
        self.y = np.hstack([self.y, leaf.y])
        self.z = np.hstack([self.z, leaf.z])
        self.f = np.hstack([self.f, leaf.f])
        self.xmin, self.xmax = min(np.min(leaf.x), self.xmin), max(np.max(leaf.x), self.xmax)
        self.ymin, self.ymax = min(np.min(leaf.y), self.ymin), max(np.max(leaf.y), self.ymax)
        self.zmin, self.zmax = min(np.min(leaf.z), self.zmin), max(np.max(leaf.z), self.zmax)
        self.fmin, self.fmax = min(np.min(leaf.f), self.fmin), max(np.max(leaf.f), self.fmax)

    def add_footprint(self, image, level):
        """Write ``level`` into ``image`` at this structure's pixels.

        ``image`` is indexed as image[z, y, x] and modified in place; it is
        also returned for convenience.
        """
        image[self.z, self.y, self.x] = level
        return image

    def plot_dendrogram(self, ax, base_level, lines):
        """Append this leaf's vertical segment to ``lines`` and return it.

        The segment sits at x = self.id and runs from the leaf's peak flux
        down to ``base_level``. ``ax`` is not used here.
        """
        line = [(self.id, np.max(self.f)), (self.id, base_level)]
        lines.append(line)
        return lines

    def set_id(self, leaf_id):
        """Assign ``leaf_id`` to this leaf and return the next free id."""
        self.id = leaf_id
        return leaf_id + 1
class Branch(Leaf):
    """A structure created where two or more structures meet.

    ``items`` are the child structures (leaves or branches) joined here, and
    each child's ``parent`` is set to this branch. The branch's own pixel
    arrays start with the pixel at which the children met and grow with
    whatever is added to the branch afterwards.
    """

    def __init__(self, items, x, y, z, f):
        self.items = items
        for child in items:
            child.parent = self
        Leaf.__init__(self, x, y, z, f)

    def __getattr__(self, attribute):
        if attribute != 'npix':
            raise AttributeError("Attribute not found: %s" % attribute)
        # Total pixel count: this branch's own pixels plus every descendant's.
        return len(self.x) + sum(child.npix for child in self.items)

    def add_footprint(self, image, level):
        """Write tree depth into ``image``: ``level`` here, deeper values below.

        Children are written first (at ``level + 1``, recursively), then this
        branch's own pixels. Returns the updated image.
        """
        for child in self.items:
            image = child.add_footprint(image, level + 1)
        return Leaf.add_footprint(self, image, level)

    def plot_dendrogram(self, ax, base_level, lines):
        """Append this branch's segments (and its subtree's) to ``lines``.

        A vertical segment at x = self.id runs from the branch's lowest flux
        down to ``base_level``; a horizontal bar at that flux spans the
        children's ids; the children then draw down to that flux.
        ``ax`` is only passed through.
        """
        merge_level = np.min(self.f)
        lines.append([(self.id, merge_level), (self.id, base_level)])
        child_ids = [child.id for child in self.items]
        lines.append([(np.min(child_ids), merge_level),
                      (np.max(child_ids), merge_level)])
        for child in self.items:
            lines = child.plot_dendrogram(ax, merge_level, lines)
        return lines

    def set_id(self, start):
        """Number the leaves below this branch consecutively from ``start``.

        Children are only (re)numbered when this branch has no id yet. The
        branch's own id is the mean of its children's ids, i.e. its x
        position in the dendrogram. Returns the next unused leaf id.
        """
        next_id = start
        if not hasattr(self, 'id'):
            for child in self.items:
                next_id = child.set_id(next_id)
        self.id = np.mean([child.id for child in self.items])
        return next_id
class Trunk(list):
    """The top-level structures of the dendrogram.

    Behaves exactly like a plain list; the subclass only names the concept.
    """
# Maps each structure id to the id of its current top-level branch, or None
# while the structure is itself top-level.
ancestor = {}

# Read in data. A 2-D image is promoted to a single-plane cube so that the
# rest of the code can always index as [z, y, x].
data = pyfits.getdata(filename)
if data.ndim == 2:
    # Keep the (ny, nx) order: reshape reinterprets memory rather than
    # transposing, so (1, nx, ny) would scramble non-square images.
    data = data.reshape(1, data.shape[0], data.shape[1])
data[np.isnan(data)] = 0.
nz, ny, nx = data.shape

# Create arrays with pixel positions.
# NOTE(review): assumes meshgrid_nd returns X, Y, Z shaped like `data`
# (nz, ny, nx) -- TODO confirm against the meshgrid module.
x = np.arange(nx)
y = np.arange(ny)
z = np.arange(nz)
X, Y, Z = meshgrid_nd(x, y, z)

# Flatten, then sort all pixels by decreasing flux.
flux, X, Y, Z = data.ravel(), X.ravel(), Y.ravel(), Z.ravel()
order = np.argsort(flux)
descending = order[::-1]
flux, X, Y, Z = flux[descending], X[descending], Y[descending], Z[descending]

# Id of the structure each pixel belongs to (0 = not yet assigned).
index = np.zeros(data.shape, dtype=int)

print("Number of points: %i" % np.sum(flux > PAR['minimum']))
# Pixel offsets (dz, dy, dx) of the six face-connected neighbours, checked
# in this order.
_NEIGHBOUR_OFFSETS = ((0, 0, -1), (0, 0, 1),
                      (0, -1, 0), (0, 1, 0),
                      (-1, 0, 0), (1, 0, 0))


def _absorb_leaves(target, target_idx, leaf_ids, items, index):
    """Merge each leaf in ``leaf_ids`` onto ``target`` and relabel its pixels.

    Every leaf is removed from ``items``, its pixels are appended to
    ``target``, and its footprint in ``index`` is overwritten with
    ``target_idx``. Returns the index map.
    """
    for leaf_idx in leaf_ids:
        removed = items.pop(leaf_idx)
        target.merge(removed)
        index = removed.add_footprint(index, target_idx)
    return index


# Loop from largest to smallest value. Each time, check if the pixel
# connects to any existing structure. Otherwise, create a new leaf.
items = {}
for i in range(len(flux)):

    if i % 10000 == 0:
        print("%i..." % i)

    # Pixels are sorted by decreasing flux, so once one is at or below the
    # threshold all remaining ones are too.
    if flux[i] <= PAR['minimum']:
        break

    xi, yi, zi, fi = int(X[i]), int(Y[i]), int(Z[i]), flux[i]

    # Ids of already-labelled face neighbours...
    adjacent = []
    for dz, dy, dx in _NEIGHBOUR_OFFSETS:
        zn, yn, xn = zi + dz, yi + dy, xi + dx
        if 0 <= zn < nz and 0 <= yn < ny and 0 <= xn < nx:
            neighbour = index[zn, yn, xn]
            if neighbour > 0:
                adjacent.append(neighbour)

    # ...resolved to their top-level structures and de-duplicated.
    adjacent = list(set(a if ancestor[a] is None else ancestor[a]
                        for a in adjacent))

    if not adjacent:
        # Isolated local maximum: start a new leaf.
        idx = next_idx()
        items[idx] = Leaf(xi, yi, zi, fi)
        index[zi, yi, xi] = idx
        ancestor[idx] = None
        continue

    if len(adjacent) == 1:
        # Touches a single structure: grow it.
        idx = adjacent[0]
        items[idx].add_point(xi, yi, zi, fi)
        index[zi, yi, xi] = idx
        continue

    # Several structures (any mix of leaves and branches) meet at this pixel.
    # Leaves that are too small, or whose peak is not far enough above this
    # pixel, are not kept separate and get merged instead.
    # type() rather than isinstance(): Branch is a subclass of Leaf.
    merge = [a for a in adjacent
             if type(items[a]) is Leaf
             and (items[a].npix < PAR['npix']
                  or items[a].fmax - fi < PAR['delta'])]
    adjacent = [a for a in adjacent if a not in merge]

    if len(adjacent) >= 2:
        # Two or more significant structures: join them under a new branch,
        # which also absorbs the insignificant leaves.
        idx = next_idx()
        target = Branch([items[a] for a in adjacent], xi, yi, zi, fi)
        items[idx] = target
        ancestor[idx] = None
        index[zi, yi, xi] = idx
        index = _absorb_leaves(target, idx, merge, items, index)
        # Everything whose top-level structure was one of the joined
        # structures now has the new branch as its top-level structure.
        joined = set(adjacent)
        for a in ancestor:
            if ancestor[a] in joined:
                ancestor[a] = idx
        for a in adjacent:
            ancestor[a] = idx
    else:
        # At most one significant structure is left. Grow it -- or, if none
        # is left, the first insignificant leaf -- and absorb the others.
        if adjacent:
            idx, absorbed = adjacent[0], merge
        else:
            idx, absorbed = merge[0], merge[1:]
        target = items[idx]
        target.add_point(xi, yi, zi, fi)
        index[zi, yi, xi] = idx
        index = _absorb_leaves(target, idx, absorbed, items, index)
# Remove orphan leaves -- leaves never joined into a branch -- that are too
# small or have too little dynamic range to be significant. Leaves that are
# children of a branch are kept: removing them would leave dangling
# references in Branch.items and in the output tables.
# type() rather than isinstance(): Branch is a subclass of Leaf.
remove = [idx for idx in items
          if type(items[idx]) is Leaf
          and items[idx].parent is None
          and (items[idx].npix < PAR['npix']
               or items[idx].fmax - items[idx].fmin < PAR['delta'])]
for idx in remove:
    # Clear the removed leaf's pixels so the index map never refers to an
    # id that is absent from the output.
    index = items.pop(idx).add_footprint(index, 0)
# The trunk holds every surviving structure that has no ancestor.
trunk = Trunk(items[idx] for idx in items if ancestor[idx] is None)

# Number the leaves left to right from 1; each branch takes the mean id of
# its children as its x position in the dendrogram.
leaf_id = 1
for item in trunk:
    leaf_id = item.set_id(leaf_id)
if PAR['format'] == 'fits':

    import pyfits

    # Map each structure object back to its integer id.
    reverse = {}
    for key in items:
        reverse[items[key]] = key

    # Split structures into leaves and branches (type(), not isinstance():
    # Branch subclasses Leaf), then record each branch's children.
    leaves = []
    branches = {}
    for idx in items:
        if type(items[idx]) is Leaf:
            leaves.append(idx)
        else:
            branches[idx] = []
    for idx in items:
        if items[idx].parent is not None:
            pidx = reverse[items[idx].parent]
            branches[pidx].append(idx)

    # One row per leaf. Built from a list of 1-tuples rather than zip(),
    # which is an iterator on Python 3 and would give an object array.
    leaves_table = np.array([(idx,) for idx in leaves], dtype=[('id', int)])

    # One row per branch: its id plus up to 6 child ids (a branch is created
    # from at most the six face neighbours of a pixel); unused slots are 0.
    branches_table = np.zeros(len(branches), dtype=[('id', int), ('children', int, 6)])
    for i, idx in enumerate(branches):
        # Write the id outside the children loop so that a branch without
        # children still gets its id rather than 0.
        branches_table['id'][i] = idx
        for j, cidx in enumerate(branches[idx]):
            branches_table['children'][i, j] = cidx

    hdu_index = pyfits.PrimaryHDU(data=index)
    hdu_leaves = pyfits.BinTableHDU(data=leaves_table, name="Leaves")
    hdu_branches = pyfits.BinTableHDU(data=branches_table, name="Branches")
    hdulist = pyfits.HDUList([hdu_index, hdu_leaves, hdu_branches])
    hdulist.writeto(PAR['output'])
# # Compute level footprint | |
# footprint = np.zeros(data.shape) | |
# for item in trunk: | |
# footprint = item.add_footprint(footprint, 1) | |
# pyfits.writeto('%s' % PAR['output'], footprint, clobber=True) | |
# os.system('gzip --best %s' % PAR['output']) | |
# | |
# # Make a plot | |
# fig = plt.figure() | |
# ax = fig.add_subplot(1, 2, 1) | |
# ax.imshow(data[0, :, :], origin='lower', interpolation='nearest') | |
# ax.set_title("Original data") | |
# ax = fig.add_subplot(1, 2, 2) | |
# ax.imshow(footprint[0, :, :], origin='lower', interpolation='nearest') | |
# ax.set_title('Levels') | |
# fig.savefig('structure.png') | |
# | |
# # Make a plot | |
# fig = plt.figure(figsize=(200, 20)) | |
# ax = fig.add_subplot(1, 1, 1) | |
# lines = [] | |
# for item in trunk: | |
# lines = item.plot_dendrogram(ax, PAR['minimum'], lines) | |
# ax.add_collection(LineCollection(lines, lw=0.25)) | |
# ax.set_xlim(0., float(leaf_id + 1)) | |
# ax.set_ylim(max(PAR['minimum'], data.max() / 1000.), data.max()) | |
# ax.set_yscale('log') | |
# fig.savefig('dendrogram.pdf') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment