Skip to content

Instantly share code, notes, and snippets.

@xu-ji
Created February 6, 2020 17:19
Show Gist options
  • Save xu-ji/22f59ce6efb257feb238a33b089a577a to your computer and use it in GitHub Desktop.
Save xu-ji/22f59ce6efb257feb238a33b089a577a to your computer and use it in GitHub Desktop.
IIC MNIST inference
from code.utils.cluster.cluster_eval import cluster_eval
from code.utils.cluster.data import cluster_twohead_create_dataloaders
import pickle
import code.archs as archs
import torch
from datetime import datetime
import sys
# Restore the training config and the best network checkpoint for run 685.
# NOTE(review): pickle.load is only acceptable here because the file is our
# own training output on a trusted path -- never unpickle untrusted data.
with open("/scratch/shared/nfs1/xuji/iid_private/685/config.pickle", "rb") as config_in:
    config = pickle.load(config_in)

# Instantiate the architecture named in the config and load the best weights.
net = archs.__dict__[config.arch](config)
net.load_state_dict(torch.load("/scratch/shared/nfs1/xuji/iid_private/685/best_net.pytorch"))
net.cuda()

#config.batch_sz = 100

# Rebuild the same two-head dataloaders that were used during training.
dataloaders_head_A, dataloaders_head_B, \
mapping_assignment_dataloader, mapping_test_dataloader = \
cluster_twohead_create_dataloaders(config)

# Sobel-filter preprocessing is applied to every dataset except MNIST.
sobel = "MNIST" not in config.dataset
"""
cluster_eval(config, net,
mapping_assignment_dataloader=mapping_assignment_dataloader,
mapping_test_dataloader=mapping_test_dataloader,
sobel=sobel, print_stats=True)
"""
# Compute test-set accuracy for the best head by remapping predicted cluster
# indices to ground-truth digits with a fixed permutation.
best_head = 0
# predicted cluster index -> ground-truth MNIST class
# (presumably the matching found during training -- TODO confirm against run 685)
mappings = {0: 9, 1: 3, 2: 1, 3: 4, 4: 7, 5: 8, 6: 5, 7: 6, 8: 0, 9: 2}
acc = 0.
ct = 0
net.eval()
# Inference only: no_grad avoids building an autograd graph per batch,
# cutting memory use during evaluation.
with torch.no_grad():
    for i, batch in enumerate(mapping_test_dataloader):
        print("batch %d %s" % (i, datetime.now()))
        sys.stdout.flush()
        imgs = batch[0].cuda()
        labels = batch[1].cuda()
        preds = net(imgs)[best_head]
        preds_flat = preds.argmax(dim=1, keepdim=False)
        preds_flat_reordered = torch.zeros(preds_flat.shape, dtype=preds_flat.dtype).cuda()
        # .items() (not the Python-2-only .iteritems()) so this runs on Python 3 too
        for pred_c, target_c in mappings.items():
            samples_pred_c = preds_flat == pred_c
            preds_flat_reordered[samples_pred_c] = target_c
        acc += (preds_flat_reordered == labels).sum().item()
        ct += labels.shape[0]
print((acc, ct)) # prints (69477.0, 70000)
print(acc/ float(ct)) # prints 0.992528571429
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment