Skip to content

Instantly share code, notes, and snippets.

@FeryET
Last active August 26, 2020 11:10
Show Gist options
  • Save FeryET/60f322cfec4029c1e34f0c2ff1e8ad75 to your computer and use it in GitHub Desktop.
Save FeryET/60f322cfec4029c1e34f0c2ff1e8ad75 to your computer and use it in GitHub Desktop.
def plot_topic_clusters(ax, x2d, y, labels):
    """Scatter-plot 2-D points on ``ax`` with one colour and legend entry per label.

    Args:
        ax: Matplotlib axes to draw on.
        x2d: Array-like whose columns 0 and 1 hold the x and y coordinates.
        y: Array-like of class indices, one per row of ``x2d``; rows where
            ``y == i`` are drawn with the colour and name of ``labels[i]``.
        labels: Sequence of legend names, ordered by class index.

    Returns:
        The same ``ax`` object that was passed in.
    """
    import numpy as np  # local import keeps the block self-contained

    # Coerce to arrays so that ``y == i`` is always an element-wise mask;
    # with a plain list it would collapse to a single ``False`` and
    # silently select nothing.
    x2d = np.asarray(x2d)
    y = np.asarray(y)

    palette = cm.get_cmap("Spectral", len(labels))
    for i, name in enumerate(labels):
        mask = y == i
        # Index the resampled colormap with the integer ``i``: the float
        # form ``i / len(labels)`` is rescaled and truncated internally and
        # can round down to the previous colour (e.g. 1 / 49 * 49 < 1).
        ax.scatter(
            x2d[mask, 0],
            x2d[mask, 1],
            color=palette(i),
            label=name,
            alpha=0.7,
        )
    ax.grid()
    ax.legend()
    # Equal data scaling on both axes; the box is resized to achieve it.
    ax.set(adjustable='box', aspect='equal')
    return ax
# Project the train and test features to 2-D (UMAP or t-SNE) and show the two
# subsets side by side, coloured by class.
# NOTE(review): the title says "PCA" although the embedding below comes from
# UMAP / t-SNE -- confirm whether the wording is intended.
title = "PCA Visualization of the Dataset using {}"

# Plain truthiness test: ``use_umap is True`` is an identity check and would
# send NumPy booleans or other truthy flags down the t-SNE branch.
if use_umap:
    from umap import UMAP  # imported lazily; only needed on this branch

    dim_reducer = UMAP(n_components=2)
    title = title.format("UMAP")
else:
    from sklearn.manifold import TSNE  # imported lazily; only needed on this branch

    dim_reducer = TSNE(n_components=2)
    title = title.format("TSNE")

# Scale and embed train and test together so that both subsets end up in the
# same 2-D coordinate system, then split them back apart by row count.
x_transform = np.concatenate((x_train, x_test))
x_transform = StandardScaler().fit_transform(x_transform)
x_transform = dim_reducer.fit_transform(x_transform)
x2d_train = x_transform[:x_train.shape[0], :]
x2d_test = x_transform[x_train.shape[0]:, :]

# One panel per subset, with shared axes so the panels are directly comparable.
fig, axes = plt.subplots(ncols=2, sharex=True, sharey=True)
plot_topic_clusters(axes[0], x2d_train, y_train, labels)
plot_topic_clusters(axes[1], x2d_test, y_test, labels)
axes[0].set_title("Train Subset")
axes[1].set_title("Test Subset")
fig.suptitle(title)
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment