Created May 11, 2018 07:27
-
-
Save bveliqi/5efe7d20c99025d02df87e4c595711c1 to your computer and use it in GitHub Desktop.
PyTorch - Tiny-ImageNet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision.datasets as datasets | |
import torch.utils.data as data | |
import torchvision.transforms as transforms | |
from logger import logger | |
from torch.autograd import Variable | |
from train import ResNet, Bottleneck | |
def evaluate(model, loader, on_batch=None):
    """Count top-1 correct predictions of ``model`` over ``loader``.

    The model is switched to evaluation mode (so layers such as Dropout and
    BatchNorm behave deterministically and BatchNorm running statistics are
    not updated). Inference runs under ``torch.no_grad()`` so no autograd
    graph is recorded.

    Args:
        model: A ``torch.nn.Module``. Its prediction for each sample is the
            arg-max over dim 1 of its output. Presumably the output is
            ``(batch, num_classes)`` scores; TODO confirm for other models.
        loader: Iterable yielding ``(images, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        on_batch: Optional callback invoked as ``on_batch(total, correct)``
            after every batch, with the running counts as plain ints.

    Returns:
        A ``(correct, total)`` tuple of Python ints.

    Note:
        Side effect: leaves ``model`` in evaluation mode.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # Convert to a Python int so the running count cannot overflow a
            # small integer dtype, and the later percentage is a true division.
            correct += int((predicted == labels).sum())
            if on_batch is not None:
                on_batch(total, correct)
    return correct, total


# --- MAIN ---
if __name__ == "__main__":
    # load model
    logger.info("Loading stored model ...")
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    # map_location='cpu' lets a checkpoint saved from a GPU load on a
    # CPU-only host; this script never moves the model to a GPU.
    # NOTE(review): absolute '/models/...' vs the relative 'data/...' path
    # below; confirm the checkpoint really lives at the filesystem root.
    model.load_state_dict(
        torch.load('/models/baseline-resnet50.pt', map_location='cpu'))
    logger.info("Loaded model successfully.")

    # load data set
    logger.info("Reading data...")
    val_dir = 'data/tiny-imagenet-200/val'
    # NOTE(review): ImageFolder expects one sub-directory per class and
    # derives class indices from the sorted directory names. These indices
    # must match the mapping used at training time. Presumably the Tiny-ImageNet
    # val split has been reorganized into per-class folders; TODO confirm.
    val_dataset = datasets.ImageFolder(val_dir, transform=transforms.ToTensor())
    val_loader = data.DataLoader(val_dataset, batch_size=32)
    logger.info("Loaded: %s", val_dir)

    correct, total = evaluate(
        model,
        val_loader,
        on_batch=lambda t, c: logger.info(
            'Progress --- total: %s, correct: %s', t, c),
    )
    if total == 0:
        raise SystemExit("No images found in %s" % val_dir)
    logger.info('Accuracy of the network on the %d validation images: %s %%',
                total, 100.0 * correct / total)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment