Skip to content

Instantly share code, notes, and snippets.

@duner
Last active October 29, 2015 22:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save duner/4efdb35320c5eeedf71f to your computer and use it in GitHub Desktop.
Save duner/4efdb35320c5eeedf71f to your computer and use it in GitHub Desktop.
Ternary Scatter Plot MatPlotLib
import numpy
import sys
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import matplotlib.cm as cm
def create_chart(data):
    """
    Draw a ternary scatter plot of ``data`` and show it with matplotlib.

    data should be a list of lists in the form [(x, y, z), a]:

    * ``(x, y, z)`` is the point. Only its first two components are used
      for positioning: the point is drawn at ``(x + y/2, sqrt(3)/2 * y)``.
    * ``a`` (the last element of the row) is the value mapped onto the
      'Reds' colormap and reported on the "Error Rate" colorbar.

    The triangle frame has side length 4 and its ticks are labelled 0..4.
    NOTE(review): presumably each point's components sum to 4 so that it
    lands inside the frame - confirm against the callers.

    Returns nothing; the figure is displayed through ``plt.show()``.

    Raises:
        ValueError: if ``data`` contains no rows.
    """
    SQRT3OVER2 = numpy.sqrt(3) / 2.

    # Pull coordinates and values out row by row.  Calling numpy.array() on
    # the ragged rows themselves raises ValueError on NumPy >= 1.24.
    rows = list(data)
    if not rows:
        raise ValueError(
            "create_chart() needs at least one [(x, y, z), value] row")
    first = numpy.array([row[0][0] for row in rows], dtype=float)
    second = numpy.array([row[0][1] for row in rows], dtype=float)
    values = [row[-1] for row in rows]

    # Project the first two components onto the 2-D plane of the triangle.
    xs = first + second / 2.
    ys = SQRT3OVER2 * second

    plt.figure(num=None, figsize=(10, 6), dpi=80, facecolor='w', edgecolor='k')

    corners = numpy.array([[0, 0], [4, 0], [2, numpy.sqrt(3) * 0.5 * 4]])
    left = corners[0]
    right = corners[1]
    top = corners[2]

    # creating the grid
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=2)

    # plotting the mesh
    plt.triplot(trimesh, '', color='0.9', zorder=1)

    # plotting the points; the colormap is passed by name because
    # plt.cm.get_cmap() no longer exists in Matplotlib >= 3.9.
    plt.scatter(xs, ys, c=values, s=100, zorder=10, cmap='Reds')

    # plotting the axes (the three edges of the triangle)
    for begin, end in ((left, right), (left, top), (right, top)):
        plt.plot([begin[0], end[0]], [begin[1], end[1]],
                 color='0.7', linestyle='-', linewidth=2)

    def plot_ticks(start, stop, tick, n):
        """Draw n + 1 labelled ticks evenly spaced from start to stop.

        Each tick is a short segment from the edge to edge + ``tick``; its
        index (0..n) is written at the outer end of the segment.
        """
        r = numpy.linspace(0, 1, n + 1)
        base_x = start[0] * (1 - r) + stop[0] * r
        base_y = start[1] * (1 - r) + stop[1] * r
        tip_x = base_x + tick[0]
        tip_y = base_y + tick[1]
        for i, (x, y) in enumerate(zip(tip_x.tolist(), tip_y.tolist())):
            plt.text(x, y, i, ha='center')
        # Each column of the stacked arrays is one (base, tip) segment.
        # No 'k' format string: it would clash with the color keyword.
        plt.plot(numpy.vstack((base_x, tip_x)), numpy.vstack((base_y, tip_y)),
                 lw=1, color='0.7')

    n = 4
    tick_size = 0.2

    # define vectors for ticks
    bottom_tick = tick_size * (right - top) / n
    right_tick = tick_size * (top - left) / n
    left_tick = tick_size * (left - right) / n

    plot_ticks(left, right, bottom_tick, n)
    plot_ticks(right, top, right_tick, n)
    plot_ticks(left, top, left_tick, n)

    plt.text(2, -.5, "Insertion Cost", ha='center')
    plt.text(.5, 2.5, "Deletion Cost", rotation=60, ha='center')
    plt.text(3.5, 2.5, "Substitution Cost", rotation=-60, ha='center')

    plt.colorbar(label="Error Rate")
    plt.axis('off')
    # plt.savefig('chart.png')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment