Skip to content

Instantly share code, notes, and snippets.

@bveliqi
Created May 11, 2018 07:27
Show Gist options
  • Save bveliqi/5efe7d20c99025d02df87e4c595711c1 to your computer and use it in GitHub Desktop.
Save bveliqi/5efe7d20c99025d02df87e4c595711c1 to your computer and use it in GitHub Desktop.
PyTorch - Tiny-ImageNet
import torch
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from logger import logger
from torch.autograd import Variable
from train import ResNet, Bottleneck
# --- MAIN ---
if __name__ == "__main__":
# load model
logger.info("Loading stored model ...")
model = ResNet(Bottleneck, [3, 4, 6, 3])
model.load_state_dict(torch.load('/models/baseline-resnet50.pt'))
logger.info("Loaded model successfully.")
# load data set
logger.info("Reading data...")
val_dir = 'data/tiny-imagenet-200/val'
val_dataset = datasets.ImageFolder(val_dir, transform=transforms.ToTensor())
val_loader = data.DataLoader(val_dataset, batch_size=32)
logger.info("Loaded: %s", val_dir)
correct = 0
total = 0
for data in val_loader:
images, labels = data
outputs = model(Variable(images))
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
logger.info('Progress --- total: %s, correct: %s', total, correct)
logger.info('Accuracy of the network on the 10000 test images: %s %%', (100 * correct / total))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment