Gist: mokemokechicken/5599e174d060a75a94aaea69597c786f
scatter with colors and legends
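import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# api, dataset and zz come from the surrounding project (not shown here):
# api.predict() returns an (n_samples, n_classes) score array, api.predict_names()
# the class names, and zz is a DataFrame with "x" and "y" columns, one row per sample.
p_list = np.argmax(api.predict(dataset.x), axis=1)  # predicted class index per sample
color_master = sns.color_palette()  # current color cycle, usually 10 colors
colors = [color_master[ptn] for ptn in p_list]  # one color per point
names = api.predict_names()

g = sns.JointGrid(data=zz, x="x", y="y")
g.plot_joint(plt.scatter, c=colors)
g.plot_marginals(sns.histplot, kde=True)  # replaces the deprecated sns.distplot

# Legend: one colored "● name" line per class, stacked down from the top-left
# corner in axes coordinates (0 to 1), so placement does not depend on the data range.
for i, (name, c) in enumerate(zip(names, color_master)):
    g.ax_joint.text(0.01, 1 - (i + 2) * 0.05, "● " + name, color=c,
                    transform=g.ax_joint.transAxes)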
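The text loop draws the legend by hand. A minimal sketch of an alternative, assuming a reasonably recent matplotlib: build one marker-only `Line2D` proxy artist per class and hand them to `Axes.legend`, which then handles layout and framing itself. The helper name `add_class_legend` and the random demo data are hypothetical stand-ins for `api.predict` / `api.predict_names`, not part of the original gist.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


def add_class_legend(ax, names, palette, loc="upper left"):
    """Attach a legend with one colored dot per class (hypothetical helper).

    names:   class names, in class-index order.
    palette: list of matplotlib colors, at least len(names) long.
    """
    # Proxy artists: empty, marker-only lines that exist only to be shown in the legend.
    handles = [
        Line2D([], [], linestyle="none", marker="o", color=palette[k], label=name)
        for k, name in enumerate(names)
    ]
    return ax.legend(handles=handles, loc=loc)


# Hypothetical demo data standing in for the project's model output.
rng = np.random.default_rng(0)
xy = rng.normal(size=(60, 2))            # 2-D coordinates, like zz["x"], zz["y"]
scores = rng.random(size=(60, 3))        # like api.predict(dataset.x)
class_ids = np.argmax(scores, axis=1)    # like p_list
names = ["cat", "dog", "bird"]           # like api.predict_names()
palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

fig, ax = plt.subplots()
ax.scatter(xy[:, 0], xy[:, 1], c=[palette[k] for k in class_ids])
legend = add_class_legend(ax, names, palette)
```

With the gist's JointGrid, `add_class_legend(g.ax_joint, names, color_master)` could take the place of the `for` loop; `color_master` works as the palette because matplotlib accepts RGB tuples as colors.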