Created January 9, 2020 16:34
-
-
Save alinazhanguwo/beb8f807d3311ea71289f43dc837318d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def overallAccuracy(clusterDF, labelsDF):
    """Score a clustering against true labels using majority-label purity.

    Each cluster is assigned the true label that occurs most often among its
    members. A member counts as "correct" when its true label equals that
    majority label. Ties are broken by the ordering of
    ``Series.value_counts()``.

    Args:
        clusterDF: DataFrame with a ``'cluster'`` column holding one cluster
            id per row. Only that column is used.
        labelsDF: single-column DataFrame (or a Series) of true labels. It is
            combined with ``clusterDF`` via ``pd.concat(axis=1)``, so rows are
            matched by index.

    Returns:
        A 6-tuple:
        countByCluster: DataFrame with columns ['cluster', 'clusterCount'],
            the number of rows in each cluster.
        countByLabel: DataFrame indexed by true label. Its 'cluster' column
            holds the number of rows (with a non-null cluster) per label.
        countMostFreqLabel: DataFrame with columns
            ['cluster', 'countMostFreqLabel', 'lable']: per cluster, how often
            its most frequent true label occurs and what that label is. The
            'lable' spelling is kept for compatibility with existing callers.
        accuracyDF: countMostFreqLabel merged with countByCluster on 'cluster'.
        overallAccuracy: sum of majority-label counts divided by the total
            number of clustered rows.
        accuracyByLabel: Series of per-cluster purity
            (countMostFreqLabel / clusterCount), row-aligned with accuracyDF.
            Despite the name, there is one entry per cluster, not per label.

    Each intermediate table is also printed to stdout.
    """
    # Size of each cluster.
    countByCluster = pd.DataFrame(data=clusterDF['cluster'].value_counts())
    countByCluster.reset_index(inplace=True, drop=False)
    countByCluster.columns = ['cluster', 'clusterCount']

    # Pair every row's true label with its cluster id. Only the 'cluster'
    # column is taken, so extra columns in clusterDF don't break the
    # two-name rename below.
    preds = pd.concat([labelsDF, clusterDF['cluster']], axis=1)
    preds.columns = ['trueLabel', 'cluster']

    countByLabel = pd.DataFrame(data=preds.groupby('trueLabel').count())
    print('countByLabel \n', countByLabel)

    # For every cluster, find its most frequent true label and how many
    # members carry it. Example: a cluster containing labels
    # {1: 20739, 3: 18, 2: 10} yields count 20739 and label 1.
    # The two aggregations are computed separately rather than passing a
    # *set* of lambdas to .agg(). Set iteration order of function objects
    # depends on their memory addresses, so that form could silently swap
    # the count and label columns.
    byCluster = preds.groupby('cluster')['trueLabel']
    countMostFreqLabel = pd.DataFrame({
        'countMostFreqLabel': byCluster.agg(lambda s: s.value_counts().iloc[0]),
        'lable': byCluster.agg(lambda s: s.value_counts().index[0]),
    })
    countMostFreqLabel.reset_index(inplace=True, drop=False)
    countMostFreqLabel.columns = ['cluster', 'countMostFreqLabel', 'lable']
    print('countMostFreqLabel \n', countMostFreqLabel, '\n \n \n')

    accuracyDF = countMostFreqLabel.merge(countByCluster, left_on="cluster", right_on="cluster")
    print('accuracyDF: i.e. dots clustered as A, how many of them have real label A \n', accuracyDF)

    # Named differently from the function itself to avoid shadowing it.
    overall_accuracy = accuracyDF.countMostFreqLabel.sum() / accuracyDF.clusterCount.sum()
    print('overallAccuracy \n', overall_accuracy)

    accuracyByLabel = accuracyDF.countMostFreqLabel / accuracyDF.clusterCount
    print('accuracyByLabel \n', accuracyByLabel, '\n===================================\n \n \n')
    return countByCluster, countByLabel, countMostFreqLabel, accuracyDF, overall_accuracy, accuracyByLabel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.