Last active
February 16, 2022 04:49
-
-
Save e96031413/704e8a316befedd0dea659f27fc34461 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def eval(loader, gt_labels_t, output_file="output.txt"):
    """Run the feature extractor and classifier over ``loader``.

    Every batch is copied into the module-level ``im_data_t`` buffer and
    into ``gt_labels_t``, passed through ``G`` and then ``F1``. The output
    of ``G.avgpool`` is captured with a forward hook and flattened to one
    row per sample. An accuracy summary is printed once the loader is
    exhausted.

    Args:
        loader: Iterable of batches; ``batch[0]`` is the image tensor and
            ``batch[1]`` the label tensor. Further elements are ignored.
        gt_labels_t: Label buffer tensor; resized and overwritten in place
            for each batch.
        output_file: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(features, labels, pred_prob, pred_result)``:
        ``features`` is the flattened ``G.avgpool`` output of all batches as
        one NumPy array, ``labels`` is a list of the ground-truth labels,
        ``pred_prob`` is column 0 of the softmax output (the original
        comment calls this the probability of class "good"), and
        ``pred_result`` holds the predicted class indices. The three arrays
        are ``None`` when the loader yields no batches.
    """
    G.eval()   # feature extractor
    F1.eval()  # classifier

    size = 0
    correct = 0
    # Per-batch arrays are gathered in lists and concatenated once at the
    # end; concatenating inside the loop re-copies all earlier batches.
    feature_chunks = []
    prob_chunks = []
    pred_chunks = []
    labels = []

    pooled = []

    def hook(module, inputs, output):
        pooled.append(output.clone().detach())

    # Register the hook once and always remove it, even if a forward pass
    # raises; otherwise it stays attached to G.avgpool.
    handle = G.avgpool.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for data_t in tqdm(loader):
                pooled.clear()
                im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])

                feat = G(im_data_t)
                # Features used for the t-SNE visualisation.
                feat_tsne = torch.flatten(pooled[0], 1)

                # Classification
                output1 = F1(feat)
                size += im_data_t.size(0)
                pred1 = output1.max(1)[1]
                softmax_prob = F.softmax(output1, dim=1)

                feature_chunks.append(feat_tsne.cpu().numpy())
                prob_chunks.append(softmax_prob[:, 0].cpu().numpy())
                pred_chunks.append(pred1.cpu().numpy())
                labels.extend(gt_labels_t.tolist())
                # .item() keeps the counter a plain int so the summary
                # prints a number rather than "tensor(...)".
                correct += pred1.eq(gt_labels_t).sum().item()
    finally:
        handle.remove()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100. * float(correct) / size if size else 0.0
    print('\n Accuracy: {}/{} F1 ({:.4f}%)\n'.format(correct, size,
                                                      accuracy))

    features = np.concatenate(feature_chunks) if feature_chunks else None
    pred_prob = np.concatenate(prob_chunks) if prob_chunks else None
    pred_result = np.concatenate(pred_chunks) if pred_chunks else None
    return features, labels, pred_prob, pred_result
# Use the testing set
features, labels, pred_prob, pred_result = eval(target_loader_unl, gt_labels_t)

# Reduce the collected features to two dimensions for plotting.
embedder = TSNE(n_components=2, random_state=999)
tsne = embedder.fit_transform(features)
xx, yy = tsne[:, 0], tsne[:, 1]

fig = plt.figure(figsize=(20, 20), dpi=80)

# Filled contours of pred_prob over the 2-D embedding, with the samples
# drawn on top and coloured by their labels.
mappable = plt.tricontourf(
    xx.ravel(),
    yy.ravel(),
    pred_prob.ravel(),
    cmap=plt.cm.Spectral,
)
scatter = plt.scatter(xx, yy, c=labels, label=labels)

label_name = ['good', 'bad']
legend_handles = scatter.legend_elements()[0]
plt.legend(handles=legend_handles, labels=label_name)

colorbar_ticks = np.linspace(0, 1., 9)
fig.colorbar(mappable, ticks=colorbar_ticks)

plt.savefig('decision-boundary.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
範例圖片: