@Vini2
Created June 9, 2020 14:39
import argparse
import csv

import scipy.special

parser = argparse.ArgumentParser(description="""Evaluate clustering results. This script returns the
precision, recall, F1-score and ARI of the provided clustering result""")
parser.add_argument("--clustered",
                    required=True,
                    type=str,
                    help="path to the .csv file with the clustering result")
parser.add_argument("--goldstandard",
                    required=True,
                    type=str,
                    help="path to the .csv file with the gold standard")
args = vars(parser.parse_args())
# Get paths to clustering result and gold standard
clustered_file = args["clustered"]
gold_standard_file = args["goldstandard"]
print("\nStarting evaluate.py...")
print("Clustering results file:", clustered_file)
print("Gold standard file:", gold_standard_file)
# Get the number of clusters from the gold standard
#---------------------------------------------------------
gold_standard_n_clusters = 0
all_gold_standard_clusters_list = []
with open(gold_standard_file) as csvfile:
    readCSV = csv.reader(csvfile, delimiter=',')
    for row in readCSV:
        all_gold_standard_clusters_list.append(row[1])
gold_standard_clusters_list = list(set(all_gold_standard_clusters_list))
gold_standard_n_clusters = len(gold_standard_clusters_list)
print("\nNumber of clusters available in the gold standard: ", gold_standard_n_clusters)
# Get the gold standard
#----------------------------
gold_standard_clusters = [[] for x in range(gold_standard_n_clusters)]
gold_standard_count = 0
with open(gold_standard_file) as contig_clusters:
    readCSV = csv.reader(contig_clusters, delimiter=',')
    for row in readCSV:
        gold_standard_count += 1
        contig = row[0]
        bin_num = gold_standard_clusters_list.index(row[1])
        gold_standard_clusters[bin_num].append(contig)
print("Number of objects available in the gold standard: ", gold_standard_count)
# Get the number of clusters from the initial clustering result
#---------------------------------------------------------
n_clusters = 0
all_clusters_list = []
with open(clustered_file) as csvfile:
    readCSV = csv.reader(csvfile, delimiter=',')
    for row in readCSV:
        all_clusters_list.append(row[1])
clusters_list = list(set(all_clusters_list))
n_clusters = len(clusters_list)
print("Number of clusters available in the clustering result: ", n_clusters)
# Get initial clustering result
#----------------------------
clusters = [[] for x in range(n_clusters)]
clustered_count = 0
clustered_objects = []
with open(clustered_file) as contig_clusters:
    readCSV = csv.reader(contig_clusters, delimiter=',')
    for row in readCSV:
        clustered_count += 1
        contig = row[0]
        bin_num = clusters_list.index(row[1])
        clusters[bin_num].append(contig)
        clustered_objects.append(contig)
print("Number of objects available in the clustering result: ", len(clustered_objects))
# Functions to determine precision, recall, F1-score and ARI
#------------------------------------------------------------
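# Get precision: for each bin, take its largest overlap with a gold-standard group,
# summed over bins as a percentage of all matched objects
def getPrecision(mat, k, s, total):
    sum_k = 0
    for i in range(k):
        max_s = 0
        for j in range(s):
            if mat[i][j] > max_s:
                max_s = mat[i][j]
        sum_k += max_s
    return sum_k/total*100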
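# Get recall: for each gold-standard group, take its largest overlap with a bin,
# summed over groups as a percentage of all gold-standard objects (unclassified included)
def getRecall(mat, k, s, total, unclassified):
    sum_s = 0
    for i in range(s):
        max_k = 0
        for j in range(k):
            if mat[j][i] > max_k:
                max_k = mat[j][i]
        sum_s += max_k
    return sum_s/(total+unclassified)*100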
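# Get ARI (adjusted Rand index, scaled to a percentage)
def getARI(mat, k, s, N):
    # t1: pairs of objects that share a bin (from row sums)
    t1 = 0
    for i in range(k):
        sum_k = 0
        for j in range(s):
            sum_k += mat[i][j]
        t1 += scipy.special.binom(sum_k, 2)
    # t2: pairs of objects that share a gold-standard group (from column sums)
    t2 = 0
    for i in range(s):
        sum_s = 0
        for j in range(k):
            sum_s += mat[j][i]
        t2 += scipy.special.binom(sum_s, 2)
    # t3: expected value of t under random labelling
    t3 = t1*t2/scipy.special.binom(N, 2)
    # t: pairs of objects that share both a bin and a gold-standard group
    t = 0
    for i in range(k):
        for j in range(s):
            t += scipy.special.binom(mat[i][j], 2)
    ari = (t-t3)/((t1+t2)/2-t3)*100
    return ari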
# Get F1-score
def getF1(prec, recall):
    return 2*prec*recall/(prec+recall)
# Determine precision, recall, F1-score and ARI for clustering result
#------------------------------------------------------------------
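total_clustered = 0
clusters_species = [[0 for x in range(gold_standard_n_clusters)] for y in range(n_clusters)]
# Sets make the membership tests below constant-time instead of scanning whole lists
cluster_sets = [set(c) for c in clusters]
gold_standard_sets = [set(c) for c in gold_standard_clusters]
for i in range(n_clusters):
    for j in range(gold_standard_n_clusters):
        # Count objects assigned to bin i that also belong to gold-standard group j
        n = 0
        for k in range(clustered_count):
            if clustered_objects[k] in cluster_sets[i] and clustered_objects[k] in gold_standard_sets[j]:
                n += 1
                total_clustered += 1
        clusters_species[i][j] = n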
print("Number of objects available in the clustering result that are present in the gold standard:", total_clustered)
my_precision = getPrecision(clusters_species, n_clusters, gold_standard_n_clusters, total_clustered)
my_recall = getRecall(clusters_species, n_clusters, gold_standard_n_clusters, total_clustered, (gold_standard_count-total_clustered))
my_ari = getARI(clusters_species, n_clusters, gold_standard_n_clusters, total_clustered)
my_f1 = getF1(my_precision, my_recall)
print("\nEvaluation Results:")
print("Precision =", my_precision)
print("Recall =", my_recall)
print("F1-score =", my_f1)
print("ARI =", my_ari)
print()
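The nested loops that fill `clusters_species` visit every (bin, group, object) combination, which can get slow for large assemblies. The following is a minimal sketch, not the gist's own method, that computes the same four scores from a sparse contingency table using only the standard library. It assumes each CSV has already been loaded into a dict mapping `point_id` to label, and that every `point_id` appears at most once. `math.comb` (Python 3.8+) replaces `scipy.special.binom`, and the function names are hypothetical.

```python
from collections import Counter
from math import comb


def build_contingency(clustered, gold):
    """Count objects per (bin_id, groundtruth_id) pair.

    clustered, gold: hypothetical dicts mapping point_id -> label.
    Only objects present in both dicts are counted, as in evaluate.py.
    """
    return Counter((clustered[obj], gold[obj]) for obj in clustered if obj in gold)


def evaluate(clustered, gold):
    """Return (precision, recall, F1, ARI) as percentages."""
    pairs = build_contingency(clustered, gold)
    total = sum(pairs.values())  # objects both clustered and in the gold standard

    # Row (bin) and column (gold-standard group) totals, plus the largest
    # overlap each bin / group has with any group / bin
    row_sums, col_sums = Counter(), Counter()
    best_in_bin, best_in_group = Counter(), Counter()
    for (b, g), n in pairs.items():
        row_sums[b] += n
        col_sums[g] += n
        best_in_bin[b] = max(best_in_bin[b], n)
        best_in_group[g] = max(best_in_group[g], n)

    precision = sum(best_in_bin.values()) / total * 100
    # Recall divides by every gold-standard object, so unclustered ones count against it
    recall = sum(best_in_group.values()) / len(gold) * 100
    f1 = 2 * precision * recall / (precision + recall)

    # Adjusted Rand index from pair counts, same terms as getARI above
    t = sum(comb(n, 2) for n in pairs.values())
    t1 = sum(comb(n, 2) for n in row_sums.values())
    t2 = sum(comb(n, 2) for n in col_sums.values())
    t3 = t1 * t2 / comb(total, 2)
    ari = (t - t3) / ((t1 + t2) / 2 - t3) * 100
    return precision, recall, f1, ari
```

For files in the two-column format, `dict(csv.reader(f))` is one way to build each input dict. Like the script, this sketch raises `ZeroDivisionError` when the ARI denominator is zero, for example when both files put every object into a single group.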
@Vini2
Author

Vini2 commented Sep 1, 2021

Hello @afsawadogo,

The clustering_result.csv file would look something like this.

NODE_1_length_1189502_cov_16.379288,res.005.fasta
NODE_2_length_1127036_cov_16.549343,res.005.fasta
NODE_3_length_1009819_cov_16.436396,res.005.fasta
NODE_4_length_861895_cov_21.063754,res.004.fasta
NODE_5_length_737013_cov_20.834031,res.004.fasta
NODE_6_length_659011_cov_21.171279,res.004.fasta
...

Make sure to follow the format point_id,bin_id.
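For a k-means run, or any other clustering, the result file needs one `point_id,bin_id` row per object and no header row. The following is a minimal sketch of writing that file. It assumes the point IDs and the integer cluster labels are available as two sequences in the same order; scikit-learn's `KMeans` exposes the labels as `labels_` after fitting. The helper name `write_clustering_result` and the `bin_` prefix are hypothetical choices.

```python
import csv


def write_clustering_result(path, point_ids, labels):
    """Write one "point_id,bin_id" row per object, with no header row.

    point_ids: sequence of object identifiers (e.g. contig names).
    labels: cluster label per object, in the same order as point_ids.
    """
    if len(point_ids) != len(labels):
        raise ValueError("point_ids and labels must have the same length")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        for point_id, label in zip(point_ids, labels):
            # Any identifier works as a bin ID; the prefix is only for readability
            writer.writerow([point_id, f"bin_{label}"])
```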

The gold_standard.csv file would look something like this.

NODE_1_length_1189502_cov_16.379288,Azorhizobium_caulinodans.fa
NODE_2_length_1127036_cov_16.549343,Azorhizobium_caulinodans.fa
NODE_3_length_1009819_cov_16.436396,Azorhizobium_caulinodans.fa
NODE_4_length_861895_cov_21.063754,Amycolatopsis_mediterranei.fa
NODE_5_length_737013_cov_20.834031,Amycolatopsis_mediterranei.fa
NODE_6_length_659011_cov_21.171279,Amycolatopsis_mediterranei.fa
...

Make sure to follow the format point_id,groundtruth_id.
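evaluate.py reads `row[1]` without checking, so a row with fewer than two fields stops the script with an `IndexError`. A header row is silently treated as one more object and group. The following is a minimal sketch of a loader that checks the two-column format before evaluation. `load_labels` is a hypothetical helper, not part of the gist. It skips blank lines and rejects repeated IDs, because a repeated `point_id` would be counted more than once by the script.

```python
import csv


def load_labels(path):
    """Read a headerless "point_id,label" CSV into a dict.

    Raises ValueError on rows that do not have exactly two fields
    or on a point_id that appears more than once.
    """
    labels = {}
    with open(path, newline="") as f:
        for line_no, row in enumerate(csv.reader(f), start=1):
            if not row:
                continue  # skip blank lines
            if len(row) != 2:
                raise ValueError(f"{path}:{line_no}: expected 2 fields, got {len(row)}")
            point_id, label = row
            if point_id in labels:
                raise ValueError(f"{path}:{line_no}: duplicate point_id {point_id!r}")
            labels[point_id] = label
    return labels
```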

bin_id and groundtruth_id need not be the same. You can use any identifier.
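The script maps each distinct label to an index only to group objects. It never compares bin IDs with ground-truth IDs, and precision, recall, F1 and ARI depend only on which objects end up grouped together. The following is a small sketch that illustrates this. The `as_partition` helper is hypothetical. It shows that two label files using different naming schemes describe the same grouping.

```python
from collections import defaultdict


def as_partition(labels):
    """Turn {point_id: label} into a set of groups, discarding the label names."""
    groups = defaultdict(set)
    for point_id, label in labels.items():
        groups[label].add(point_id)
    return {frozenset(members) for members in groups.values()}


# Same grouping written with two different identifier styles
numeric_ids = {"NODE_1": "0", "NODE_2": "0", "NODE_3": "1"}
file_ids = {"NODE_1": "res.005.fasta", "NODE_2": "res.005.fasta", "NODE_3": "res.004.fasta"}
assert as_partition(numeric_ids) == as_partition(file_ids)
```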

Hope this helps. Let me know if you have any more questions.

Thank you!
Vijini

@afsawadogo

Hello @Vini2,
Thank you for your quick reply.
I am doing clustering with the k-means algorithm and want to evaluate my clustering result using your code. I'm very new to machine learning and data science.

Thanks again
