Skip to content

Instantly share code, notes, and snippets.

@davidandrzej
Created April 24, 2011 20:11
Show Gist options
  • Save davidandrzej/939840 to your computer and use it in GitHub Desktop.
Save davidandrzej/939840 to your computer and use it in GitHub Desktop.
3-simplex triangular scatter plot
"""
Visualize points on the 3-simplex (e.g., the parameters of a
3-dimensional multinomial distribution) as a scatter plot
contained within a 2D triangle.
David Andrzejewski (david.andrzej@gmail.com)
"""
import numpy as NP
import matplotlib.pyplot as P
import matplotlib.ticker as MT
import matplotlib.lines as L
import matplotlib.cm as CM
import matplotlib.colors as C
import matplotlib.patches as PA
def plotSimplex(points, fig=None,
                vertexlabels=('1', '2', '3'),
                **kwargs):
    """
    Plot an Nx3 points array on the 3-simplex as a 2D triangle scatter plot
    (with optionally labeled vertices).

    points       -- N x 3 array of points; projected with projectSimplex.
    fig          -- Figure to draw into; a new pyplot figure is created
                    when None.
    vertexlabels -- three labels for the lower-left, lower-right and top
                    vertices, matching columns 0, 1 and 2 of `points`.
    kwargs       -- passed along directly to the axes' scatter method.

    Returns the Figure; the caller must .show().
    """
    if fig is None:
        fig = P.figure()
    ax = fig.gca()
    # Draw the triangle as a closed outline through its three vertices
    outline = L.Line2D([0, 0.5, 1.0, 0],            # xcoords
                       [0, NP.sqrt(3) / 2, 0, 0],   # ycoords
                       color='k')
    ax.add_line(outline)
    ax.xaxis.set_major_locator(MT.NullLocator())
    ax.yaxis.set_major_locator(MT.NullLocator())
    # Draw vertex labels just outside each corner
    ax.text(-0.05, -0.05, vertexlabels[0])
    ax.text(1.05, -0.05, vertexlabels[1])
    ax.text(0.5, NP.sqrt(3) / 2 + 0.05, vertexlabels[2])
    # Project and draw the actual points on *this* figure's axes;
    # pyplot.scatter would draw on whichever figure is current, which is
    # not necessarily `fig` when the caller supplied one.
    projected = projectSimplex(points)
    ax.scatter(projected[:, 0], projected[:, 1], **kwargs)
    # Leave some buffer around the triangle for vertex labels
    ax.set_xlim(-0.2, 1.2)
    ax.set_ylim(-0.2, 1.2)
    # Use equal data scaling so the triangle is drawn equilateral
    ax.set_aspect('equal')
    return fig
def projectSimplex(points):
    """
    Project probabilities on the 3-simplex to a 2D triangle.

    N points are given as an N x 3 array (or anything NP.asarray can turn
    into one, e.g. a nested list). Returns an N x 2 float array of (x, y)
    coordinates in the triangle with vertices (0, 0), (1, 0) and
    (0.5, sqrt(3)/2), which correspond to columns 0, 1 and 2 respectively.
    Rows are expected to sum to 1; other rows are still mapped but need not
    land inside the triangle.
    """
    points = NP.asarray(points, dtype=float)
    # Computed for all rows at once with array arithmetic
    p1 = points[:, 0]
    p2 = points[:, 1]
    p3 = points[:, 2]
    # Each component moves the centroid toward its own vertex along the
    # bisector out of that vertex; the centroid-to-vertex distance of a
    # unit-side triangle is 1/sqrt(3).
    step = 1.0 / NP.sqrt(3)
    cos30 = NP.cos(NP.pi / 6)
    sin30 = NP.sin(NP.pi / 6)
    # Init to triangle centroid
    x = 1.0 / 2
    y = 1.0 / (2 * NP.sqrt(3))
    # Vector 1 - bisect out of lower left vertex
    x = x - step * p1 * cos30
    y = y - step * p1 * sin30
    # Vector 2 - bisect out of lower right vertex
    x = x + step * p2 * cos30
    y = y - step * p2 * sin30
    # Vector 3 - bisect out of top vertex
    y = y + step * p3
    return NP.column_stack((x, y))
if __name__ == '__main__':
    # Define a synthetic test dataset; each label is the text of its point
    labels = ('[0.1 0.1 0.8]',
              '[0.8 0.1 0.1]',
              '[0.5 0.4 0.1]',
              '[0.33 0.34 0.33]')
    testpoints = NP.array([[0.1, 0.1, 0.8],
                           [0.8, 0.1, 0.1],
                           [0.5, 0.4, 0.1],
                           [0.33, 0.34, 0.33]])
    # Define different colors for each label.
    # The old 'spectral' colormap no longer exists in modern matplotlib
    # (it was renamed 'nipy_spectral'), and matplotlib.cm.get_cmap has been
    # deprecated/removed; pyplot.get_cmap is the supported lookup.
    cmap = P.get_cmap('nipy_spectral')
    norm = C.Normalize(vmin=0, vmax=len(labels))
    c = list(range(len(labels)))
    # Do scatter plot
    fig = plotSimplex(testpoints, s=100, c=c,
                      cmap=cmap, norm=norm)
    # Make color-label legend from proxy patches colored like the points
    fig.gca().legend([PA.Rectangle((0, 0), 1, 1,
                                   fc=cmap(norm(idx)))
                      for idx in range(len(labels))],
                     labels)
    P.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment