Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created July 12, 2021 10:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e96031413/d22b047daaaf8e761924aa58de970012 to your computer and use it in GitHub Desktop.
PyTorch如何列出分類錯誤之原始圖片路徑? (How to list the original image file paths of misclassified samples in PyTorch)
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields each sample together with its file path.

    Returning the source path alongside the image and label lets evaluation
    code report exactly which original image files were misclassified.

    Args:
        dataframe: table indexed positionally via ``.iloc`` (presumably a
            pandas DataFrame) with a ``"file_path"`` column holding the image
            path and a ``"class"`` column holding the label.
        transform: callable applied to the image opened with ``Image.open``.
    """

    def __init__(self, dataframe, transform):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(image, label, file_path)`` for the row at position *index*.

        ``image`` is ``transform(Image.open(file_path))`` and ``label`` is the
        row's ``"class"`` value wrapped by ``np.asarray``.
        """
        record = self.dataframe.iloc[index]
        source_path = record["file_path"]
        # Open and transform first, then read the label (same order as before).
        sample = self.transform(Image.open(source_path))
        target = np.asarray(record["class"])
        return sample, target, source_path
# val_loader is built on CustomDataset, so each batch yields (images, target, path).
def testing(UnNormalize, writer, val_loader, model, criterion, args, size_val_df, y_pred, y_true,
            misclassified_log_path="missclassified_file_path.txt"):
    """Evaluate *model* on *val_loader* and record misclassified image paths.

    For every misclassified sample (any batch size), a line
    ``"<file path> <true label> <predicted label>"`` is appended to
    *misclassified_log_path*.

    Args:
        UnNormalize: forwarded unchanged to ``test_data_recorder``.
        writer: forwarded unchanged to ``test_data_recorder``.
        val_loader: iterable of ``(images, target, path)`` batches, e.g. a
            DataLoader over ``CustomDataset``.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss function called as ``criterion(output, target)``.
        args: namespace providing ``gpu`` and ``print_freq``.
        size_val_df: denominator for the running loss/accuracy values
            (presumably the total number of validation samples).
        y_pred: list extended in place with predicted class indices.
        y_true: list extended in place with ground-truth class indices.
        misclassified_log_path: file that misclassified-sample lines are
            appended to. Defaults to the historical file name.

    Returns:
        ``(top1.avg, last running loss, last running accuracy, y_pred, y_true)``.

    Raises:
        IndexError: if *val_loader* yields no batches (no running values recorded).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    val_running_loss = 0.0
    val_loss_values = []
    val_running_accuracy = 0
    val_accuracy_values = []

    with torch.no_grad():
        end = time.time()
        for i, (images, target, path) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output; preds holds the argmax class index per sample
            output = model(images)
            _, preds = torch.max(output, 1)
            loss = criterion(output, target)

            # Per-sample comparison works for any batch size (the old
            # tensor `!=` check only worked when batch_size == 1).
            _log_misclassified(path, target, preds, misclassified_log_path)

            y_pred.extend(preds.view(-1).detach().cpu().numpy())
            y_true.extend(target.view(-1).detach().cpu().numpy())

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            val_running_loss += loss.item() * images.size(0)
            val_loss_values.append(val_running_loss / size_val_df)
            val_running_accuracy += torch.sum(preds == target.data)
            val_accuracy_values.append(val_running_accuracy.double() / size_val_df)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # NOTE(review): `i` is passed twice; kept as-is because
            # test_data_recorder's signature is defined elsewhere.
            test_data_recorder(UnNormalize, i, preds, writer, target, images, output, i, val_loader)

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg, val_loss_values[-1], val_accuracy_values[-1], y_pred, y_true


def _log_misclassified(paths, targets, preds, log_path):
    """Append one ``"path true_label pred_label"`` line per misclassified sample.

    Args:
        paths: per-sample source file paths for the batch, aligned with
            *targets* and *preds* (the default collate turns the string
            paths from CustomDataset into a list).
        targets: tensor of ground-truth class indices for the batch.
        preds: tensor of predicted class indices for the batch.
        log_path: file to append to. It is opened only when the batch
            contains at least one misclassification, so a fully correct
            run never creates it.
    """
    true_labels = targets.view(-1).cpu().tolist()
    pred_labels = preds.view(-1).cpu().tolist()
    lines = [
        f"{sample_path} {true_label} {pred_label}\n"
        for sample_path, true_label, pred_label in zip(paths, true_labels, pred_labels)
        if true_label != pred_label
    ]
    if lines:
        # UTF-8 so non-ASCII file paths cannot fail under a narrow locale codec.
        with open(log_path, "a", encoding="utf-8") as log_file:
            log_file.writelines(lines)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment