Created
February 26, 2019 07:52
-
-
Save escuccim/fc81a81c433483337303cc330e2e2bc1 to your computer and use it in GitHub Desktop.
Cluster Purity
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy import stats | |
## Cluster purity | |
def purity(truth, pred):
    """Compute the purity of a clustering against ground-truth labels.

    Each predicted cluster is credited with the number of its members that
    belong to the ground-truth class most frequent within that cluster.
    Purity is the sum of those counts divided by the total number of points.

    Parameters
    ----------
    truth : array-like
        Ground-truth label of every point. Flattened to 1-D.
    pred : array-like
        Predicted cluster label of every point, aligned with ``truth``.
        Flattened to 1-D.

    Returns
    -------
    float
        Purity of the clustering; ``nan`` when the inputs are empty.
    """
    # Accept lists/tuples and column vectors as well as 1-D arrays.
    truth = np.asarray(truth).ravel()
    pred = np.asarray(pred).ravel()

    # No points: purity is undefined; avoid a division-by-zero warning.
    if pred.size == 0:
        return np.nan

    majority_counts = []
    # loop through clusters and count the dominant ground-truth class in each
    for pred_cluster in np.unique(pred):
        gt_partition = truth[pred == pred_cluster]
        # Count label occurrences directly instead of indexing into the
        # result of scipy.stats.mode, whose return shape differs across
        # SciPy versions; np.unique also works for non-numeric labels.
        _, counts = np.unique(gt_partition, return_counts=True)
        majority_counts.append(counts.max())

    return np.sum(majority_counts) / pred.size
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment