Skip to content

Instantly share code, notes, and snippets.

@MortisHuang
Created July 24, 2019 07:03
Show Gist options
  • Save MortisHuang/d756357c85aea83bb0b7790e696b3eb1 to your computer and use it in GitHub Desktop.
Save MortisHuang/d756357c85aea83bb0b7790e696b3eb1 to your computer and use it in GitHub Desktop.
# Compile the clustering model: SGD built from the positional arguments
# (0.01, 0.9) and the 'kld' loss.
model.compile(loss='kld', optimizer=SGD(0.01, 0.9))
#%%
# Run k-means (20 initialisations) on the encoder output for x to obtain the
# starting cluster labels.
kmeans = KMeans(n_init=20, n_clusters=n_clusters)
y_pred = kmeans.fit_predict(encoder.predict(x))
#%%
# Keep a copy of the initial labels for the convergence check, and load the
# k-means centres as the weights of the layer named 'clustering'.
y_pred_last = y_pred.copy()
model.get_layer('clustering').set_weights([kmeans.cluster_centers_])
#%%
# computing an auxiliary target distribution
def target_distribution(q):
    """Return the auxiliary target distribution derived from ``q``.

    Each entry of ``q`` is squared and divided by its column sum
    (``q.sum(0)``); every row of the result is then renormalised so that it
    sums to 1.

    Args:
        q: 2-D array. The training loop passes the model's prediction and
            takes ``argmax(1)`` of it as the cluster label, so rows are
            samples and columns are clusters.

    Returns:
        Array with the same shape as ``q`` whose rows each sum to 1.
    """
    # Squaring sharpens confident assignments; dividing by the per-column
    # total down-weights columns that already carry a lot of mass.
    weight = q ** 2 / q.sum(0)
    # Transpose so the row sums broadcast across columns, then transpose back.
    return (weight.T / weight.sum(1)).T
# --- training schedule ---
maxiter = 8000          # upper bound on the number of training iterations
update_interval = 140   # iterations between refreshes of the target distribution
tol = 0.001 # tolerance threshold to stop training

# --- mutable loop state ---
loss = 0                # most recent batch loss (reported in the progress line)
index = 0               # position of the current mini-batch within index_array
index_array = np.arange(x.shape[0])  # one entry per row of x
#%%
# Training loop: run at most `maxiter` mini-batch updates. Every
# `update_interval` iterations the model's prediction on all of x is used to
# refresh the auxiliary target distribution p, report metrics (when ground
# truth y is available) and test the stop criterion.
for ite in range(int(maxiter)):
    if ite % update_interval == 0:
        q = model.predict(x, verbose=0)
        p = target_distribution(q)  # update the auxiliary target distribution p

        # evaluate the clustering performance
        y_pred = q.argmax(1)
        if y is not None:
            # accu / nmis / aris are defined elsewhere in the file; presumably
            # clustering accuracy, NMI and ARI -- confirm against their imports.
            acc = np.round(accu(y, y_pred), 5)
            nmi = np.round(nmis(y, y_pred), 5)
            ari = np.round(aris(y, y_pred), 5)
            loss = np.round(loss, 5)
            print('Iter %d: acc = %.5f, nmi = %.5f, ari = %.5f' % (ite, acc, nmi, ari), ' ; loss=', loss)

        # check stop criterion - model convergence
        # delta_label is the fraction of samples whose label changed since
        # the previous refresh.
        delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
        y_pred_last = np.copy(y_pred)
        if ite > 0 and delta_label < tol:
            print('delta_label ', delta_label, '< tol ', tol)
            print('Reached tolerance threshold. Stopping training.')
            break

    # Train on the next contiguous mini-batch; the final batch of a pass may
    # be shorter than batch_size.
    idx = index_array[index * batch_size: min((index + 1) * batch_size, x.shape[0])]
    loss = model.train_on_batch(x=x[idx], y=p[idx])
    # Strict '<' so the cursor wraps as soon as the data is exhausted. With
    # '<=' and a sample count that is an exact multiple of batch_size, the
    # cursor advanced one step too far and the next batch was empty.
    index = index + 1 if (index + 1) * batch_size < x.shape[0] else 0
# Write the final weights under save_dir, then read the same file back into
# the model.
model.save_weights(f'{save_dir}/DEC_model_final.h5')
model.load_weights(f'{save_dir}/DEC_model_final.h5')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment