import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

def eval(loader, gt_labels_t, output_file="output.txt"):
    G.eval()   # feature extractor
    F1.eval()  # classifier
    size = 0
    correct = 0
    y_pred = []
    y_true = []
    pred_prob = None
    pred_result = None
    features = None
    labels = []
    with torch.no_grad():
        for batch_idx, data_t in enumerate(tqdm(loader)):
            # Capture the backbone's pooled features with a temporary forward hook
            feat_list = []
            def hook(module, input, output):
                feat_list.append(output.clone().detach())
            handle = G.avgpool.register_forward_hook(hook)
            im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
            gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
            paths = data_t[2]
            component_name = data_t[3]
            feat = G(im_data_t)
            feats = torch.flatten(feat_list[0], 1).unsqueeze_(-1).unsqueeze_(-1)
            feat_tsne = torch.flatten(feat_list[0], 1)
            handle.remove()
            # Classification
            output1 = F1(feat)
            size += im_data_t.size(0)
            pred1 = output1.data.max(1)[1]
            # Softmax probabilities
            softmax_prob = F.softmax(output1, dim=1)
            current_pred_prob = softmax_prob[:, 0].cpu().numpy()  # probability of the "good" class
            current_features = feat_tsne.cpu().numpy()            # features for t-SNE visualization
            current_prediction = pred1.cpu().numpy()
            # Accumulate features
            if features is not None:
                features = np.concatenate((features, current_features))
            else:
                features = current_features
            # Accumulate prediction probabilities
            if pred_prob is not None:
                pred_prob = np.concatenate((pred_prob, current_pred_prob))
            else:
                pred_prob = current_pred_prob
            # Accumulate prediction results
            if pred_result is not None:
                pred_result = np.concatenate((pred_result, current_prediction))
            else:
                pred_result = current_prediction
            target = gt_labels_t.tolist()
            for element in target:
                labels.append(element)
            y_pred.extend(pred1.view(-1).detach().cpu().numpy())
            y_true.extend(gt_labels_t.view(-1).detach().cpu().numpy())
            correct += pred1.eq(gt_labels_t.data).cpu().sum()
    print('\n Accuracy: {}/{} F1 ({:.4f}%)\n'.format(correct, size,
                                                     100. * float(correct) / size))
    return features, labels, pred_prob, pred_result
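# --- Context (assumed, not part of the original gist) ------------------------
# eval() relies on module-level objects defined by the surrounding training
# script: a backbone G with an .avgpool layer, a classifier F1, and the
# pre-allocated tensors im_data_t / gt_labels_t. The concrete choices below
# (ResNet-34 backbone, a 512->2 linear head) are only a sketch of that setup.
import torch.nn as nn
from torchvision import models

G = models.resnet34()          # feature extractor; its avgpool output is hooked in eval()
G.fc = nn.Identity()           # expose the pooled 512-d features to F1
F1 = nn.Linear(512, 2)         # binary classifier: class 0 = "good", class 1 = "bad"

im_data_t = torch.FloatTensor(1)   # resized/filled with each image batch
gt_labels_t = torch.LongTensor(1)  # resized/filled with each label batch
if torch.cuda.is_available():
    G, F1 = G.cuda(), F1.cuda()
    im_data_t, gt_labels_t = im_data_t.cuda(), gt_labels_t.cuda()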
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Run evaluation on the testing set
features, labels, pred_prob, pred_result = eval(target_loader_unl, gt_labels_t)

# Project the extracted features to 2D with t-SNE
tsne = TSNE(n_components=2, random_state=999).fit_transform(features)
xx = tsne[:, 0]
yy = tsne[:, 1]
fig = plt.figure(figsize=(20, 20), dpi=80)
# Filled contours of the "good"-class probability approximate the decision boundary
mappable = plt.tricontourf(xx.ravel(), yy.ravel(), pred_prob.ravel(), cmap=plt.cm.Spectral)
scatter = plt.scatter(xx, yy, c=labels, label=labels)
label_name = ['good', 'bad']
plt.legend(handles=scatter.legend_elements()[0], labels=label_name)
fig.colorbar(mappable, ticks=np.linspace(0, 1., 9))
plt.savefig('decision-boundary.png')
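# --- Optional summary (assumption: label 0 = "good", label 1 = "bad") --------
# A quick sanity check on the arrays returned by eval(), using standard
# scikit-learn metrics; this is a suggested addition, not part of the gist.
from sklearn.metrics import classification_report, confusion_matrix

print(confusion_matrix(labels, pred_result))
print(classification_report(labels, pred_result, target_names=['good', 'bad']))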
Example image:
decisionBoundary
