Skip to content

Instantly share code, notes, and snippets.

@yjzhang
Created June 13, 2019 22:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yjzhang/226b8ddef8d6bd7be5a6709f7970b1a5 to your computer and use it in GitHub Desktop.
Save yjzhang/226b8ddef8d6bd7be5a6709f7970b1a5 to your computer and use it in GitHub Desktop.
spectral coclustering heatmap
import numpy as np
# `sklearn.cluster.bicluster` was deprecated in scikit-learn 0.22 and removed
# in 0.24; the class is exported from `sklearn.cluster` in old and new
# releases alike.
from sklearn.cluster import SpectralCoclustering

# Spectral co-clustering heatmap of a cluster-vs-cluster count matrix.
#
# NOTE(review): `cluster_counts`, `cluster_cell_types`, `plt` and `sns` are
# not defined in this snippet; presumably they come from the surrounding
# session (a 2-D count array, one name per row, matplotlib.pyplot and
# seaborn respectively) -- confirm before running this stand-alone.

# Co-cluster rows and columns into 18 biclusters.
spec = SpectralCoclustering(18)
# NOTE(review): this subset (row 31 dropped) is never used below; the fit and
# the plot both use the full `cluster_counts`. Kept in case a later cell
# reads it.
cluster_counts_subset = np.vstack([cluster_counts[:31, :], cluster_counts[32:, :]])
# Small offset added before fitting; presumably to keep all-zero rows/columns
# from breaking the normalisation inside the algorithm.
spec.fit(cluster_counts + 0.0001)
row_labels = spec.row_labels_
column_labels = spec.column_labels_

# Reorder rows and columns so members of the same bicluster are adjacent.
row_order = np.argsort(row_labels)
col_order = np.argsort(column_labels)
cluster_counts_reordered = cluster_counts[row_order, :]
cluster_counts_reordered = cluster_counts_reordered[:, col_order]

# Tick labels of the form "<bicluster label>: <name or original index>".
cluster_cell_types_2 = np.array(
    ['{}: {}'.format(x, y) for x, y in zip(row_labels, cluster_cell_types)]
)
col_labels = np.array(
    ['{}: {}'.format(x, y) for y, x in enumerate(column_labels)]
)

plt.figure(figsize=(25, 30))
# Each row is divided by its own total, so cells show the fraction of that
# row's counts falling in each column (hence the fixed 0..1 colour scale).
ax = sns.heatmap(
    cluster_counts_reordered / cluster_counts_reordered.sum(1)[:, np.newaxis],
    yticklabels=cluster_cell_types_2[row_order],
    xticklabels=col_labels[col_order],
    vmin=0,
    vmax=1,
)

# Draw a separator wherever the (sorted) bicluster label changes. Using the
# change points directly avoids a stray line at index 0 when the smallest
# label present is not 0, and draws each separator exactly once.
for boundary in np.flatnonzero(np.diff(row_labels[row_order])) + 1:
    ax.axhline(boundary, linewidth=3)
for boundary in np.flatnonzero(np.diff(column_labels[col_order])) + 1:
    ax.axvline(boundary, linewidth=3)

plt.xlabel('UNCURL clusters')
plt.ylabel('Seurat clusters')
plt.title('SCH Cerebellum Clusters')
plt.savefig('uncurl_vs_seurat_clusters_coclustering.png', dpi=150)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment