Created
July 12, 2021 10:07
-
-
Save e96031413/d22b047daaaf8e761924aa58de970012 to your computer and use it in GitHub Desktop.
PyTorch如何列出分類錯誤之原始圖片路徑？
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, label, file_path)`` triples from a DataFrame.

    Returning the source path next to each sample lets evaluation code report
    which original image files were misclassified.

    Args:
        dataframe: table with one row per sample; each row must provide a
            ``"file_path"`` column (path of the image file) and a ``"class"``
            column (the label).
        transform: callable applied to the opened PIL image.
    """

    def __init__(self, dataframe, transform):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(transformed_image, label, file_path)`` for positional row ``index``."""
        row = self.dataframe.iloc[index]
        file_path = row["file_path"]
        # Open via a context manager so the file handle is released right away;
        # a bare ``Image.open`` keeps the file open until the image object is
        # garbage-collected, leaking descriptors and raising ResourceWarning.
        with Image.open(file_path) as img:
            img.load()  # decode pixel data now, while the file is still open
            image = self.transform(img)
        label = np.asarray(row["class"])
        return (image, label, file_path)
def _log_misclassified(paths, target, preds, log_path):
    """Append one ``"<path> <true_label> <pred_label>"`` line per misclassified sample.

    Args:
        paths: sequence of file paths for the batch (as collated by the DataLoader).
        target: tensor of true class indices for the batch.
        preds: tensor of predicted class indices for the batch.
        log_path: file the lines are appended to; it is only opened when the
            batch contains at least one misclassified sample.
    """
    true_labels = target.view(-1).cpu().tolist()
    pred_labels = preds.view(-1).cpu().tolist()
    lines = [
        f"{p} {t} {pr}\n"  # path, actual label, predicted label
        for p, t, pr in zip(paths, true_labels, pred_labels)
        if t != pr
    ]
    if not lines:
        return
    # UTF-8 so non-ASCII file paths are written correctly regardless of locale.
    with open(log_path, "a", encoding="utf-8") as f:
        f.writelines(lines)


# val_loader is expected to be built on the CustomDataset class, i.e. it yields
# (images, target, path) batches.
def testing(UnNormalize, writer, val_loader, model, criterion, args, size_val_df,
            y_pred, y_true, misclassified_log="missclassified_file_path.txt"):
    """Evaluate ``model`` on ``val_loader`` and log misclassified image paths.

    For every misclassified sample (any batch size), a line
    ``"<path> <true_label> <pred_label>"`` is appended to ``misclassified_log``.

    Args:
        UnNormalize: forwarded unchanged to ``test_data_recorder``.
        writer: forwarded unchanged to ``test_data_recorder``.
        val_loader: iterable of ``(images, target, path)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss function called as ``criterion(output, target)``.
        args: namespace; ``args.gpu`` and ``args.print_freq`` are read.
        size_val_df: divisor for the running loss and correct-count totals
            (presumably the number of validation samples — confirm with caller).
        y_pred: list extended in place with predicted class indices.
        y_true: list extended in place with true class indices.
        misclassified_log: path of the misclassification log file, or ``None``
            to disable logging.

    Returns:
        ``(top1.avg, loss_total / size_val_df, correct_total / size_val_df,
        y_pred, y_true)``; the third element is a float64 tensor.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    val_running_loss = 0.0
    val_loss_values = []
    val_running_accuracy = 0
    val_accuracy_values = []

    with torch.no_grad():
        end = time.time()
        for i, (images, target, path) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            _, preds = torch.max(output, 1)  # index of the max logit per sample
            loss = criterion(output, target)

            if misclassified_log is not None:
                _log_misclassified(path, target, preds, misclassified_log)

            y_pred.extend(preds.view(-1).detach().cpu().numpy())
            y_true.extend(target.view(-1).detach().cpu().numpy())

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            loss_value = loss.item()  # one device sync instead of two
            losses.update(loss_value, n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)

            val_running_loss += loss_value * n
            val_loss_values.append(val_running_loss / size_val_df)
            val_running_accuracy += torch.sum(preds == target.data)
            val_accuracy_values.append(val_running_accuracy.double() / size_val_df)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # NOTE(review): `i` is passed twice here, as in the original call —
            # verify against test_data_recorder's signature.
            test_data_recorder(UnNormalize, i, preds, writer, target, images, output, i, val_loader)

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg, val_loss_values[-1], val_accuracy_values[-1], y_pred, y_true
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment