@e96031413
Last active June 11, 2022 11:30
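# Evaluate clustering performance: run K-means on the learned features and
# compare the cluster assignments with the ground-truth labels using NMI.
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score

def evaluation(X, Y):
    # One cluster per distinct ground-truth class (robust to non-contiguous labels)
    classN = len(np.unique(Y))
    kmeans = KMeans(n_clusters=classN, n_init=10).fit(X)
    nmi = normalized_mutual_info_score(Y, kmeans.labels_, average_method='arithmetic')
    return nmi

# Collect model outputs and labels over the test set.
# `model` and `test_loader` are assumed to be defined elsewhere.
model.eval()
outputs, targets = [], []
with torch.no_grad():
    for input, target in test_loader:
        outputs.append(model(input).cpu())  # move features to CPU for NumPy
        targets.append(target.cpu())        # labels must be on CPU too
testdata = torch.cat(outputs, 0)
testlabel = torch.cat(targets, 0)
nmi = evaluation(testdata.numpy(), testlabel.numpy())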
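The snippet above needs a trained `model` and a test loader to run. As a minimal, self-contained sanity check, the same K-means + NMI procedure can be exercised on synthetic data. This sketch assumes only NumPy and scikit-learn. Gaussian blobs from `make_blobs` stand in for learned features. `evaluation` is restated here with an added `random_state` so runs are reproducible. `manual_nmi` is a hypothetical helper that recomputes the arithmetic-mean NMI by hand from the contingency table, using NMI = I(Y; C) / ((H(Y) + H(C)) / 2), so the library value can be cross-checked.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics.cluster import normalized_mutual_info_score


def evaluation(X, Y):
    # K-means with one cluster per ground-truth class, scored by NMI
    classN = len(np.unique(Y))
    kmeans = KMeans(n_clusters=classN, n_init=10, random_state=0).fit(X)
    return normalized_mutual_info_score(Y, kmeans.labels_, average_method="arithmetic")


def manual_nmi(y_true, y_pred):
    # Hypothetical helper: NMI = I(Y; C) / ((H(Y) + H(C)) / 2)
    _, yi = np.unique(y_true, return_inverse=True)
    _, ci = np.unique(y_pred, return_inverse=True)
    joint = np.zeros((yi.max() + 1, ci.max() + 1))
    np.add.at(joint, (yi, ci), 1)   # contingency counts
    joint /= joint.sum()            # joint probability P(y, c)
    py, pc = joint.sum(axis=1), joint.sum(axis=0)

    def entropy(p):
        p = p[p > 0]
        return -np.sum(p * np.log(p))

    nz = joint > 0
    mi = np.sum(joint[nz] * np.log(joint[nz] / np.outer(py, pc)[nz]))
    denom = (entropy(py) + entropy(pc)) / 2
    # Both labelings trivial (single class each): treat as perfect agreement
    return mi / denom if denom > 0 else 1.0


# Synthetic stand-in for model outputs: 3 well-separated 2-D clusters
X, Y = make_blobs(n_samples=300, centers=3, cluster_std=0.5, random_state=0)
nmi = evaluation(X, Y)
print(f"NMI (scikit-learn): {nmi:.4f}")

# Cross-check the library value against the hand-computed formula
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X).labels_
manual = manual_nmi(Y, labels)
print(f"NMI (manual):       {manual:.4f}")
```

NMI is invariant to how clusters are numbered, which is why K-means labels can be compared with class labels directly without first matching cluster IDs to classes.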