Skip to content

Instantly share code, notes, and snippets.

@duner
Last active October 29, 2015 22:37
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
Star You must be signed in to star a gist
Embed
What would you like to do?
Ternary Scatter Plot (Matplotlib)
import numpy
import sys
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import matplotlib.cm as cm
_SQRT3_OVER_2 = numpy.sqrt(3) / 2.0


def _project_points(points):
    """Map ternary points onto 2-D Cartesian coordinates.

    A point ``(a, b, ...)`` is placed at ``(a + b/2, sqrt(3)/2 * b)``. Only the
    first two components are read, so a third component does not affect the
    position (presumably it is implied by a fixed component sum).

    Returns two float arrays ``(xs, ys)``; both are empty when *points* is empty.
    """
    # reshape keeps the (N, 2) shape even for N == 0, so slicing below is safe.
    ab = numpy.array([(p[0], p[1]) for p in points], dtype=float).reshape(-1, 2)
    xs = ab[:, 0] + ab[:, 1] / 2.0
    ys = _SQRT3_OVER_2 * ab[:, 1]
    return xs, ys


def _plot_ticks(start, stop, tick, n):
    """Draw n + 1 evenly spaced ticks along the edge from *start* to *stop*.

    Each tick is a segment from a point on the edge to that point plus the
    offset vector *tick*. The tick at fraction i/n of the edge is labelled with
    the integer ``i`` at the segment's outer end.
    """
    frac = numpy.linspace(0, 1, n + 1)
    base_x = start[0] * (1 - frac) + stop[0] * frac
    base_y = start[1] * (1 - frac) + stop[1] * frac
    tip_x = base_x + tick[0]
    tip_y = base_y + tick[1]
    for i, (x, y) in enumerate(zip(tip_x, tip_y)):
        plt.text(x, y, str(i), ha='center')
    # Given 2-row arrays, plt.plot draws one line per column, i.e. one segment
    # per tick. The colour is set only via ``color``: combining it with a colour
    # format string such as 'k' makes Matplotlib warn about a redundant colour.
    plt.plot(numpy.vstack((base_x, tip_x)), numpy.vstack((base_y, tip_y)),
             lw=1, color='0.7')


def create_chart(data, show=True):
    """Draw a ternary scatter plot of *data* and return the figure.

    Args:
        data: iterable of entries of the form ``[(x, y, z), value]``.

            - Each triple is projected onto an equilateral triangle with
              corners (0, 0), (4, 0) and (2, 2*sqrt(3)) via
              ``(x + y/2, sqrt(3)/2 * y)``.
            - So (4, 0, 0) lands on the bottom-right corner, (0, 4, 0) on the
              top corner and (0, 0, 4) on the bottom-left corner.
            - Triples are not normalised. Presumably each triple sums to 4,
              the triangle's side length; confirm against callers.
            - ``value`` colours the point using the 'Reds' colormap.
        show: if True (the default), call ``plt.show()`` before returning.

    Returns:
        The Matplotlib figure that was drawn on.
    """
    # Unpack entries explicitly. Building one ndarray from ragged
    # [(x, y, z), value] rows raises ValueError on NumPy >= 1.24.
    entries = list(data)
    xs, ys = _project_points(entry[0] for entry in entries)
    values = [entry[-1] for entry in entries]

    fig = plt.figure(num=None, figsize=(10, 6), dpi=80, facecolor='w', edgecolor='k')

    side = 4.0
    corners = numpy.array([[0.0, 0.0], [side, 0.0], [side / 2.0, _SQRT3_OVER_2 * side]])
    left, right, top = corners

    # Light background grid: the triangle uniformly refined two levels deep.
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    trimesh = tri.UniformTriRefiner(triangle).refine_triangulation(subdiv=2)
    plt.triplot(trimesh, '', color='0.9', zorder=1)

    # Data points, drawn above the grid. The colormap is passed by name because
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    scatter = plt.scatter(xs, ys, c=values, s=100, zorder=10, cmap='Reds')

    # The three triangle edges.
    for start, stop in ((left, right), (left, top), (right, top)):
        plt.plot([start[0], stop[0]], [start[1], stop[1]],
                 color='0.7', linestyle='-', linewidth=2)

    # Ticks labelled 0..n on every edge. Each tick vector is parallel to another
    # edge and scaled to tick_size / n of that edge's length.
    n = 4
    tick_size = 0.2
    _plot_ticks(left, right, tick_size * (right - top) / n, n)
    _plot_ticks(right, top, tick_size * (top - left) / n, n)
    _plot_ticks(left, top, tick_size * (left - right) / n, n)

    plt.text(2, -.5, "Insertion Cost", ha='center')
    plt.text(.5, 2.5, "Deletion Cost", rotation=60, ha='center')
    plt.text(3.5, 2.5, "Substitution Cost", rotation=-60, ha='center')

    # Attach the colorbar to the scatter explicitly. It is skipped when there
    # are no values, since an empty scatter has no data range to show.
    if values:
        plt.colorbar(scatter, label="Error Rate")
    plt.axis('off')
    if show:
        plt.show()
    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment