Skip to content

Instantly share code, notes, and snippets.

@astrofrog
Created November 11, 2010 17:26
Show Gist options
  • Save astrofrog/672850 to your computer and use it in GitHub Desktop.
Save astrofrog/672850 to your computer and use it in GitHub Desktop.
fun with astronomical dendrograms (test code)
import sys
import os
import getopt
import pyfits
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from meshgrid import meshgrid_nd
# Last structure index handed out by next_idx(); 0 means "none yet"
IDX_COUNTER = 0


def next_idx():
    """Return a new, previously unused structure index.

    The module-level counter is incremented before it is returned, so the
    first call yields 1 and the value 0 is never handed out.
    """
    global IDX_COUNTER
    IDX_COUNTER += 1
    return IDX_COUNTER
# Command line: [options] FILENAME
#
# Everything between the script name and the last argument is parsed as
# options; the last argument is the input FITS file. Without this guard
# sys.argv[-1] would be the script itself when no argument is given.
if len(sys.argv) < 2:
    sys.stderr.write("Usage: [-m minimum] [-o output] [-n min_n_pix] "
                     "[-d min_delta] filename\n")
    sys.exit(2)

try:
    opts, args = getopt.getopt(sys.argv[1:-1], "m:o:n:d:",
                               ["minimum=", "output=",
                                "min_n_pix=", "min_delta="])
except getopt.GetoptError as err:
    print(str(err))
    sys.exit(2)

filename = sys.argv[-1]

# Default parameters, overridden below by any options given
PAR = {
    # Pixels at or below this value are ignored
    'minimum': -np.inf,
    # Output file, derived from the input name unless -o is given
    'output': filename.replace('.fits', '').replace('.gz', '') + "_dendrogram.fits",
    # Minimum number of pixels for a leaf to be kept separate
    'npix': 0,
    # Minimum value contrast for a leaf to be kept separate
    'delta': 0.,
    # Output format (no command-line option sets this)
    'format': 'fits',
}

for o, a in opts:
    if o in ("-m", "--minimum"):
        PAR['minimum'] = float(a)
    elif o in ("-o", "--output"):
        PAR['output'] = a
    elif o in ("-n", "--min_n_pix"):
        PAR['npix'] = int(a)
    elif o in ("-d", "--min_delta"):
        PAR['delta'] = float(a)
    else:
        # getopt only returns declared options, so this indicates that an
        # option was declared above without a matching handler. Raised
        # explicitly so the check survives `python -O`.
        raise AssertionError("unhandled option")
class Leaf(object):
    """A structure made of a growing set of pixels.

    Stores the pixel coordinates (``x``, ``y``, ``z``) and values (``f``) as
    1D arrays, and keeps the bounding box (``xmin``/``xmax`` etc.) and the
    value range (``fmin``/``fmax``) up to date as pixels are added.

    ``parent`` starts out as None. ``id`` only exists once `set_id` has
    been called. ``npix`` is computed on demand through `__getattr__`.
    """

    def __init__(self, x, y, z, f):
        """Create a leaf containing the single pixel (x, y, z) of value f."""
        self.x = np.array([x], dtype=int)
        self.y = np.array([y], dtype=int)
        self.z = np.array([z], dtype=int)
        self.f = np.array([f], dtype=float)
        self.xmin, self.xmax = x, x
        self.ymin, self.ymax = y, y
        self.zmin, self.zmax = z, z
        self.fmin, self.fmax = f, f
        self.parent = None

    def __getattr__(self, attribute):
        """Provide the computed attribute ``npix`` (number of pixels).

        Only called when normal attribute lookup fails. Raises
        AttributeError for anything else so that hasattr() and
        getattr(obj, name, default) behave as usual.
        """
        if attribute == 'npix':
            return len(self.x)
        raise AttributeError("Attribute not found: %s" % attribute)

    def add_point(self, x, y, z, f):
        "Add point to current leaf"
        self.x = np.hstack([self.x, x])
        self.y = np.hstack([self.y, y])
        self.z = np.hstack([self.z, z])
        self.f = np.hstack([self.f, f])
        self.xmin, self.xmax = min(x, self.xmin), max(x, self.xmax)
        self.ymin, self.ymax = min(y, self.ymin), max(y, self.ymax)
        self.zmin, self.zmax = min(z, self.zmin), max(z, self.zmax)
        self.fmin, self.fmax = min(f, self.fmin), max(f, self.fmax)

    def merge(self, leaf):
        "Append all pixels of another leaf to this one and widen the bounds"
        self.x = np.hstack([self.x, leaf.x])
        self.y = np.hstack([self.y, leaf.y])
        self.z = np.hstack([self.z, leaf.z])
        self.f = np.hstack([self.f, leaf.f])
        self.xmin, self.xmax = min(np.min(leaf.x), self.xmin), max(np.max(leaf.x), self.xmax)
        self.ymin, self.ymax = min(np.min(leaf.y), self.ymin), max(np.max(leaf.y), self.ymax)
        self.zmin, self.zmax = min(np.min(leaf.z), self.zmin), max(np.max(leaf.z), self.zmax)
        self.fmin, self.fmax = min(np.min(leaf.f), self.fmin), max(np.max(leaf.f), self.fmax)

    def add_footprint(self, image, level):
        "Fill in a map which shows the depth of the tree"
        # The image is indexed as [z, y, x] and modified in place
        image[self.z, self.y, self.x] = level
        return image

    def plot_dendrogram(self, ax, base_level, lines):
        """Append this leaf's vertical line segment to `lines` and return it.

        The segment runs from the leaf's maximum value down to `base_level`
        at horizontal position ``self.id``. `ax` is not used here.
        """
        line = [(self.id, np.max(self.f)), (self.id, base_level)]
        lines.append(line)
        return lines

    def set_id(self, leaf_id):
        "Store `leaf_id` as this leaf's id and return the next free id"
        self.id = leaf_id
        return leaf_id + 1
class Branch(Leaf):
    """A structure that has child structures in addition to its own pixels.

    ``items`` is the list of children; each child's ``parent`` is set to
    this branch on construction. The branch's own pixel arrays are handled
    by the Leaf machinery.
    """

    def __init__(self, items, x, y, z, f):
        """Create a branch over `items`, seeded with pixel (x, y, z) of value f."""
        self.items = items
        for item in items:
            item.parent = self
        Leaf.__init__(self, x, y, z, f)

    def __getattr__(self, attribute):
        """Provide ``npix``: own pixels plus those of all descendants.

        Raises AttributeError for any other missing attribute.
        """
        if attribute == 'npix':
            return len(self.x) + sum(item.npix for item in self.items)
        raise AttributeError("Attribute not found: %s" % attribute)

    def add_footprint(self, image, level):
        "Paint the children at `level` + 1, then this branch's own pixels at `level`"
        for item in self.items:
            image = item.add_footprint(image, level + 1)
        return Leaf.add_footprint(self, image, level)

    def plot_dendrogram(self, ax, base_level, lines):
        """Append the segments for this branch and its descendants to `lines`.

        Adds a vertical segment from the branch's minimum value down to
        `base_level`, a horizontal segment spanning the children's ids at
        the branch's minimum value, and then lets every child draw itself
        with that minimum value as its base level. Returns `lines`.
        """
        floor = np.min(self.f)
        lines.append([(self.id, floor), (self.id, base_level)])
        items_ids = [item.id for item in self.items]
        lines.append([(np.min(items_ids), floor),
                      (np.max(items_ids), floor)])
        for item in self.items:
            lines = item.plot_dendrogram(ax, floor, lines)
        return lines

    def set_id(self, start):
        """Number the children starting at `start` and return the next free id.

        The children are only numbered if this branch has no id yet. The
        branch's own id is the mean of its children's ids.
        """
        item_id = start
        # The check does not depend on the child, so it is made once
        if not hasattr(self, 'id'):
            for item in self.items:
                item_id = item.set_id(item_id)
        self.id = np.mean([item.id for item in self.items])
        return item_id
class Trunk(list):
    """A plain list used to hold the structures that have no ancestor."""
# Maps a structure index to the index of the branch it currently belongs
# to, or None while it has no parent
ancestor = {}

# Read in data
data = pyfits.getdata(filename)

# Promote a 2D image to a single-plane cube. The two existing axes keep
# their order: swapping them in the reshape would scramble the rows of any
# non-square image.
if len(data.shape) == 2:
    data = data.reshape((1,) + data.shape)

data[np.isnan(data)] = 0.

nz, ny, nx = data.shape

# Create arrays with pixel positions
x = np.arange(data.shape[2])
y = np.arange(data.shape[1])
z = np.arange(data.shape[0])
# NOTE(review): assumes meshgrid_nd returns arrays with the same shape and
# element order as `data`, so that the ravel below keeps each coordinate
# aligned with its pixel value - TODO confirm against the meshgrid module
X, Y, Z = meshgrid_nd(x, y, z)

# Convert to 1D
flux, X, Y, Z = data.ravel(), X.ravel(), Y.ravel(), Z.ravel()

# Sort in decreasing order of value
order = np.argsort(flux)[::-1]
flux, X, Y, Z = flux[order], X[order], Y[order], Z[order]

# Define index of what item each cell is part of (0 = not assigned)
index = np.zeros(data.shape, dtype=int)

print("Number of points: %i" % np.sum(flux > PAR['minimum']))
# Loop from largest to smallest value. Each time, check if the pixel
# connects to any existing structure. Otherwise, create a new leaf.
items = {}

# (dz, dy, dx) offsets of the six face-connected neighbours of a pixel
NEIGHBOUR_OFFSETS = ((0, 0, -1), (0, 0, 1),
                     (0, -1, 0), (0, 1, 0),
                     (-1, 0, 0), (1, 0, 0))

for i in range(len(flux)):

    if i % 10000 == 0:
        print("%i..." % i)

    # Pixels are visited in decreasing order of value, so once one is at or
    # below the threshold, all the remaining ones are too
    if flux[i] <= PAR['minimum']:
        break

    # Collect the structure indices of the neighbours that have already
    # been assigned (0 in the index map means "not assigned")
    adjacent = []
    for dz, dy, dx in NEIGHBOUR_OFFSETS:
        zn, yn, xn = Z[i] + dz, Y[i] + dy, X[i] + dx
        if 0 <= zn < nz and 0 <= yn < ny and 0 <= xn < nx \
                and index[zn, yn, xn] > 0:
            adjacent.append(index[zn, yn, xn])

    # Replace every structure that already has an ancestor by that
    # ancestor, then drop duplicates
    adjacent = [k if ancestor[k] is None else ancestor[k] for k in adjacent]
    adjacent = list(set(adjacent))

    n_adjacent = len(adjacent)

    if n_adjacent == 0:  # Create new leaf

        # Set absolute index of the new element
        idx = next_idx()

        # Create leaf and add it to the overall list
        items[idx] = Leaf(X[i], Y[i], Z[i], flux[i])

        # Set absolute index of pixel in index map
        index[Z[i], Y[i], X[i]] = idx

        # Create new entry for ancestor
        ancestor[idx] = None

    elif n_adjacent == 1:  # Add to existing leaf or branch

        # Get absolute index of adjacent element
        idx = adjacent[0]

        # Add point to the adjacent item
        items[idx].add_point(X[i], Y[i], Z[i], flux[i])

        # Set absolute index of pixel in index map
        index[Z[i], Y[i], X[i]] = idx

    else:  # Merge leaves

        # At this stage, the adjacent items might consist of an arbitrary
        # number of leaves and branches.

        # Find all leaves that are not important enough to be kept
        # separate. The exact type test is deliberate: Branch is a subclass
        # of Leaf and branches must never be merged away.
        merge = [k for k in adjacent
                 if type(items[k]) is Leaf
                 and (items[k].npix < PAR['npix']
                      or items[k].fmax - flux[i] < PAR['delta'])]

        # Remove merges from list of adjacent items
        adjacent = [k for k in adjacent if k not in merge]

        if len(adjacent) == 0:

            # There are no separate leaves left (and no branches), so pick
            # the first one as the reference and merge the others onto it
            idx = merge[0]
            absorbed = merge[1:]
            target = items[idx]
            target.add_point(X[i], Y[i], Z[i], flux[i])

        elif len(adjacent) == 1:

            # A single leaf or branch survives: it takes the current point
            # and all the leaves to merge
            idx = adjacent[0]
            absorbed = merge
            target = items[idx]
            target.add_point(X[i], Y[i], Z[i], flux[i])

        else:

            # Several structures survive: join them under a new branch
            # that starts with the current point
            idx = next_idx()
            absorbed = merge
            target = Branch([items[j] for j in adjacent],
                            X[i], Y[i], Z[i], flux[i])
            items[idx] = target
            ancestor[idx] = None

        # Set absolute index of pixel in index map
        index[Z[i], Y[i], X[i]] = idx

        # A separate loop variable is used so that the pixel index `i` is
        # not overwritten
        for merged_idx in absorbed:
            # Remove leaf
            removed = items.pop(merged_idx)
            # Merge old leaf onto reference structure
            target.merge(removed)
            # Update index map
            index = removed.add_footprint(index, idx)

        if len(adjacent) > 1:
            # The new branch becomes the ancestor of its children and of
            # everything that previously pointed at one of them
            for j in adjacent:
                ancestor[j] = idx
                for key in ancestor:
                    if ancestor[key] == j:
                        ancestor[key] = idx
# Remove leaves that aren't large enough.
# NOTE(review): the test below applies to every plain Leaf in `items`, not
# only to those without a parent - confirm that this is intended.
# The exact type test is deliberate: Branch is a subclass of Leaf.
remove = [idx for idx in items
          if type(items[idx]) is Leaf
          and (items[idx].npix < PAR['npix']
               or items[idx].fmax - items[idx].fmin < PAR['delta'])]

for idx in remove:
    items.pop(idx)

# Create trunk from objects with no ancestors
trunk = Trunk()
for idx in items:
    if ancestor[idx] is None:
        trunk.append(items[idx])

# Number the structures of the trunk; each call returns the next free id
leaf_id = 1
for item in trunk:
    leaf_id = item.set_id(leaf_id)
if PAR['format'] == 'fits':

    # Create reverse dictionary (structure object -> index in `items`)
    reverse = dict((items[key], key) for key in items)

    # Split the structure indices into leaves and branches. The exact type
    # test is deliberate: Branch is a subclass of Leaf.
    leaves = []
    branches = {}
    for idx in items:
        if type(items[idx]) is Leaf:
            leaves.append(idx)
        else:
            branches[idx] = []

    # Register every structure that has a parent as a child of that parent
    for idx in items:
        parent = items[idx].parent
        if parent is not None:
            branches[reverse[parent]].append(idx)

    # One row per leaf. An explicit list of 1-tuples is used because a
    # bare zip() is an iterator on Python 3, which np.array cannot turn
    # into a structured array.
    leaves_table = np.array([(idx,) for idx in leaves], dtype=[('id', int)])

    # One row per branch, with six child slots: branches are built from
    # the distinct neighbours of one pixel, of which there are at most
    # six. Unused slots stay 0, which is never a valid structure index.
    branches_table = np.zeros(len(branches),
                              dtype=[('id', int), ('children', int, 6)])

    for i, idx in enumerate(branches):
        branches_table['id'][i] = idx
        for j, cidx in enumerate(branches[idx]):
            branches_table['children'][i, j] = cidx

    # Primary HDU: index map; extensions: leaves and branches tables
    hdu_index = pyfits.PrimaryHDU(data=index)
    hdu_leaves = pyfits.BinTableHDU(data=leaves_table, name="Leaves")
    hdu_branches = pyfits.BinTableHDU(data=branches_table, name="Branches")

    hdulist = pyfits.HDUList([hdu_index, hdu_leaves, hdu_branches])
    hdulist.writeto(PAR['output'])
# # Compute level footprint
# footprint = np.zeros(data.shape)
# for item in trunk:
# footprint = item.add_footprint(footprint, 1)
# pyfits.writeto('%s' % PAR['output'], footprint, clobber=True)
# os.system('gzip --best %s' % PAR['output'])
#
# # Make a plot
# fig = plt.figure()
# ax = fig.add_subplot(1, 2, 1)
# ax.imshow(data[0, :, :], origin='lower', interpolation='nearest')
# ax.set_title("Original data")
# ax = fig.add_subplot(1, 2, 2)
# ax.imshow(footprint[0, :, :], origin='lower', interpolation='nearest')
# ax.set_title('Levels')
# fig.savefig('structure.png')
#
# # Make a plot
# fig = plt.figure(figsize=(200, 20))
# ax = fig.add_subplot(1, 1, 1)
# lines = []
# for item in trunk:
# lines = item.plot_dendrogram(ax, PAR['minimum'], lines)
# ax.add_collection(LineCollection(lines, lw=0.25))
# ax.set_xlim(0., float(leaf_id + 1))
# ax.set_ylim(max(PAR['minimum'], data.max() / 1000.), data.max())
# ax.set_yscale('log')
# fig.savefig('dendrogram.pdf')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment