Created
June 9, 2020 14:39
-
-
Save Vini2/6586736ec9958ea2902fca98200def76 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import scipy.special | |
import csv | |
# Parse the command-line arguments
#---------------------------------------------------------
parser = argparse.ArgumentParser(description="""Evaluate clustering results. This scripts will return the
precision, recall, F1-score and ARI of the provided clustering result""")

parser.add_argument("--clustered",
                    required=True,
                    type=str,
                    help="path to the .csv file with the clustering result")

parser.add_argument("--goldstandard",
                    required=True,
                    type=str,
                    help="path to the .csv file with the gold standard")

# vars() turns the argparse Namespace into a plain dict keyed by option name.
args = vars(parser.parse_args())

# Get paths to clustering result and gold standard
clustered_file = args["clustered"]
gold_standard_file = args["goldstandard"]

print("\nStarting evaluate.py...")
print("Clustering results file:", clustered_file)
print("gold standard file:", gold_standard_file)
# Load the gold standard (rows of the form: object_id,cluster_label)
#---------------------------------------------------------
# The file is read once and kept in memory. Blank lines (for example a
# trailing newline at the end of the file) are skipped, because indexing
# row[1] on an empty row would raise IndexError.
with open(gold_standard_file, newline="") as csvfile:
    gold_standard_rows = [row for row in csv.reader(csvfile, delimiter=',') if row]

# Get the number of clusters from the gold standard
#---------------------------------------------------------
all_gold_standard_clusters_list = [row[1] for row in gold_standard_rows]
# dict.fromkeys removes duplicates while keeping first-seen order, so the
# index assigned to each label is the same on every run.
gold_standard_clusters_list = list(dict.fromkeys(all_gold_standard_clusters_list))
gold_standard_n_clusters = len(gold_standard_clusters_list)

print("\nNumber of clusters available in the gold standard: ", gold_standard_n_clusters)

# Get the gold standard
#----------------------------
# Label -> position in gold_standard_clusters_list; avoids a linear
# list.index() scan for every row.
gold_standard_bin_index = {
    label: position for position, label in enumerate(gold_standard_clusters_list)
}

# gold_standard_clusters[j] holds the object ids assigned to label j.
gold_standard_clusters = [[] for _ in range(gold_standard_n_clusters)]
for row in gold_standard_rows:
    gold_standard_clusters[gold_standard_bin_index[row[1]]].append(row[0])

gold_standard_count = len(gold_standard_rows)

print("Number of objects available in the gold standard: ", gold_standard_count)
# Load the clustering result (rows of the form: object_id,bin_label)
#---------------------------------------------------------
# The file is read once and kept in memory. Blank lines (for example a
# trailing newline at the end of the file) are skipped, because indexing
# row[1] on an empty row would raise IndexError.
with open(clustered_file, newline="") as csvfile:
    clustered_rows = [row for row in csv.reader(csvfile, delimiter=',') if row]

# Get the number of clusters from the initial clustering result
#---------------------------------------------------------
all_clusters_list = [row[1] for row in clustered_rows]
# dict.fromkeys removes duplicates while keeping first-seen order, so the
# index assigned to each bin label is the same on every run.
clusters_list = list(dict.fromkeys(all_clusters_list))
n_clusters = len(clusters_list)

print("Number of clusters available in the clustering result: ", n_clusters)

# Get initial clustering result
#----------------------------
# Bin label -> position in clusters_list; avoids a linear list.index()
# scan for every row.
cluster_bin_index = {label: position for position, label in enumerate(clusters_list)}

# clusters[i] holds the object ids assigned to bin i; clustered_objects
# holds every object id in file order.
clusters = [[] for _ in range(n_clusters)]
clustered_objects = []
for row in clustered_rows:
    clusters[cluster_bin_index[row[1]]].append(row[0])
    clustered_objects.append(row[0])

clustered_count = len(clustered_rows)

print("Number of objects available in the clustering result: ", len(clustered_objects))
# Functions to determine precision, recall, F1-score and ARI
#------------------------------------------------------------
# Get precision
def getPrecision(mat, k, s, total):
    """Return the precision of a clustering as a percentage.

    For each of the first ``k`` rows of ``mat`` the largest of its first
    ``s`` entries is taken; those maxima are summed and divided by ``total``.

    Args:
        mat: Contingency matrix indexed as ``mat[row][column]``. The caller
            in this script passes clusters as rows and gold standard
            clusters as columns; entries are expected to be counts.
        k: Number of rows to read.
        s: Number of columns to read.
        total: Denominator; the caller passes the number of clustered
            objects that are present in the gold standard.

    Returns:
        The precision in percent, or ``0.0`` when ``total`` is zero
        (instead of raising ``ZeroDivisionError``).
    """
    if total == 0:
        return 0.0
    # default=0 covers s == 0, where a row contributes nothing.
    sum_k = sum(
        max((mat[i][j] for j in range(s)), default=0)
        for i in range(k)
    )
    return sum_k / total * 100
# Get recall | |
def getRecall(mat, k, s, total, unclassified):
    """Return the recall of a clustering as a percentage.

    For each of the first ``s`` columns of ``mat`` the largest of its first
    ``k`` entries is taken; those maxima are summed and divided by
    ``total + unclassified``.

    Args:
        mat: Contingency matrix indexed as ``mat[row][column]``. The caller
            in this script passes clusters as rows and gold standard
            clusters as columns; entries are expected to be counts.
        k: Number of rows to read.
        s: Number of columns to read.
        total: Number of clustered objects present in the gold standard.
        unclassified: Number of gold standard objects that were not
            counted in ``total``.

    Returns:
        The recall in percent, or ``0.0`` when ``total + unclassified`` is
        zero (instead of raising ``ZeroDivisionError``).
    """
    denominator = total + unclassified
    if denominator == 0:
        return 0.0
    # default=0 covers k == 0, where a column contributes nothing.
    sum_s = sum(
        max((mat[j][i] for j in range(k)), default=0)
        for i in range(s)
    )
    return sum_s / denominator * 100
# Get ARI | |
def getARI(mat, k, s, N):
    """Return the Adjusted Rand Index of a clustering as a percentage.

    Uses the contingency-table form of the index: pair counts are taken
    with ``scipy.special.binom(n, 2)`` over row totals, column totals and
    individual cells.

    Args:
        mat: Contingency matrix indexed as ``mat[row][column]``.
        k: Number of rows to read.
        s: Number of columns to read.
        N: Total number of objects counted in ``mat``.

    Returns:
        The ARI multiplied by 100.

    Note:
        Degenerate inputs (for example ``N < 2``, or a zero denominator
        when both partitions are a single cluster) are not special-cased.
    """
    binom = scipy.special.binom

    row_totals = [sum(mat[i][j] for j in range(s)) for i in range(k)]
    col_totals = [sum(mat[i][j] for i in range(k)) for j in range(s)]

    # Number of object pairs that share a row / share a column.
    t1 = sum(binom(row_total, 2) for row_total in row_totals)
    t2 = sum(binom(col_total, 2) for col_total in col_totals)

    # Expected number of agreeing pairs under random assignment.
    t3 = t1 * t2 / binom(N, 2)

    # Observed number of pairs that share both a row and a column.
    t = sum(binom(mat[i][j], 2) for i in range(k) for j in range(s))

    ari = (t - t3) / ((t1 + t2) / 2 - t3) * 100
    return ari
# Get F1-score | |
def getF1(prec, recall):
    """Return the F1-score, the harmonic mean of precision and recall.

    Args:
        prec: Precision, on any scale (the caller passes a percentage).
        recall: Recall, on the same scale as ``prec``.

    Returns:
        ``2 * prec * recall / (prec + recall)``, or ``0.0`` when
        ``prec + recall`` is zero (instead of raising
        ``ZeroDivisionError``).
    """
    denominator = prec + recall
    if denominator == 0:
        return 0.0
    return 2 * prec * recall / denominator
# Determine precision, recall, F1-score and ARI for clustering result
#------------------------------------------------------------------

# Map each object id to the indices of the clusters / gold standard clusters
# that contain it. Looking these up per object replaces scanning every
# (cluster, gold cluster) pair with list membership tests, while producing
# the same counts.
cluster_ids_of = {}
for cluster_idx, members in enumerate(clusters):
    for contig in members:
        cluster_ids_of.setdefault(contig, set()).add(cluster_idx)

gold_ids_of = {}
for gold_idx, members in enumerate(gold_standard_clusters):
    for contig in members:
        gold_ids_of.setdefault(contig, set()).add(gold_idx)

# clusters_species[i][j] = number of clustered objects that are in cluster i
# and in gold standard cluster j.
total_clustered = 0
clusters_species = [[0] * gold_standard_n_clusters for _ in range(n_clusters)]

for contig in clustered_objects:
    # Objects missing from the gold standard have no gold index and are
    # therefore not counted.
    for cluster_idx in cluster_ids_of.get(contig, ()):
        for gold_idx in gold_ids_of.get(contig, ()):
            clusters_species[cluster_idx][gold_idx] += 1
            total_clustered += 1

print("Number of objects available in the clustering result that are present in the gold standard:", total_clustered)

my_precision = getPrecision(clusters_species, n_clusters, gold_standard_n_clusters, total_clustered)
# Gold standard objects that were never matched count against recall.
my_recall = getRecall(clusters_species, n_clusters, gold_standard_n_clusters, total_clustered, (gold_standard_count-total_clustered))
my_ari = getARI(clusters_species, n_clusters, gold_standard_n_clusters, total_clustered)
my_f1 = getF1(my_precision, my_recall)

print("\nEvaluation Results:")
print("Precision =", my_precision)
print("Recall =", my_recall)
print("F1-score =", my_f1)
print("ARI =", my_ari)

print()
Hello @afsawadogo,
The `clustering_result.csv` file would look something like this:
NODE_1_length_1189502_cov_16.379288,res.005.fasta
NODE_2_length_1127036_cov_16.549343,res.005.fasta
NODE_3_length_1009819_cov_16.436396,res.005.fasta
NODE_4_length_861895_cov_21.063754,res.004.fasta
NODE_5_length_737013_cov_20.834031,res.004.fasta
NODE_6_length_659011_cov_21.171279,res.004.fasta
...
Make sure to follow the format `point_id,bin_id`.
The `gold_standard.csv` file would look something like this:
NODE_1_length_1189502_cov_16.379288,Azorhizobium_caulinodans.fa
NODE_2_length_1127036_cov_16.549343,Azorhizobium_caulinodans.fa
NODE_3_length_1009819_cov_16.436396,Azorhizobium_caulinodans.fa
NODE_4_length_861895_cov_21.063754,Amycolatopsis_mediterranei.fa
NODE_5_length_737013_cov_20.834031,Amycolatopsis_mediterranei.fa
NODE_6_length_659011_cov_21.171279,Amycolatopsis_mediterranei.fa
...
Make sure to follow the format `point_id,groundtruth_id`.
`bin_id` and `groundtruth_id` need not be the same. You can use any identifier.
Hope this helps. Let me know if you have any more questions.
Thank you!
Vijini
Hello @Vini2,
Thank you for your quick reply.
I am doing clustering using the K-means algorithm and I want to evaluate my clustering result using your code. I'm very new to machine learning and data science.
Thanks again
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello Vijini, thank you for the script. Would you mind showing me an example of the clustering_result.csv and gold_standard.csv files?
Thank you