Last active
November 23, 2018 16:27
-
-
Save gngdb/c5855e10dea83c99a44b338acc76759f to your computer and use it in GitHub Desktop.
Updated this script to run on up-to-date PyTorch: https://github.com/szagoruyko/functional-zoo/blob/master/imagenet-validation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# dependencies | |
import re | |
import os | |
import argparse | |
import torch | |
from tqdm import tqdm | |
#import cv2 | |
import numpy as np | |
import torch.utils.data | |
import torchnet as tnt | |
from torchvision import transforms as cvtransforms | |
import torchvision.datasets as datasets | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from torch.utils import model_zoo | |
# input arguments
parser = argparse.ArgumentParser(description='PyTorch ImageNet validation')
parser.add_argument('--imagenetpath', metavar='PATH', required=True,
                    help='path to dataset')
parser.add_argument('--numthreads', default=4, type=int, metavar='N',
                    help='number of data loading threads (default: 4)')
# NOTE: no code in this script reads ``args.model`` -- the weights are always
# downloaded from a fixed URL -- so the flag must not be mandatory. It is
# still accepted so that existing command lines keep working.
parser.add_argument('--model', metavar='PATH', default=None,
                    help='path to model (accepted for compatibility; '
                         'currently unused)')
# functional model definition from functional zoo: https://github.com/szagoruyko/functional-zoo/blob/master/imagenet-validation.py#L27-L47
def define_model(params):
    """Build a functional forward pass for a four-group residual network.

    The number of residual blocks in each group is inferred from ``params``
    by counting keys that start with ``group<j>.block<i>.conv0.weight``.

    Args:
        params: mapping from parameter name to tensor. The forward pass
            reads ``conv0.*``, ``fc.*`` and, for every block,
            ``group<j>.block<i>.conv{0,1,2}.{weight,bias}``; the first block
            of each group additionally needs ``conv_dim.{weight,bias}``.

    Returns:
        A function ``f(input, params)`` that returns the output of the final
        linear layer for a batch of images.
    """
    def conv2d(input, params, base, stride=1, pad=0):
        """Apply the convolution whose weight/bias are stored under ``base``."""
        return F.conv2d(input, params[base + '.weight'],
                        params[base + '.bias'], stride, pad)

    def group(input, params, base, stride, n):
        """Run ``n`` residual blocks; only the first one applies ``stride``."""
        o = input
        for i in range(n):
            b_base = '%s.block%d.conv' % (base, i)
            x = o
            o = F.relu(conv2d(x, params, b_base + '0'))
            # Only the first block of a group changes the resolution.
            block_stride = stride if i == 0 else 1
            o = F.relu(conv2d(o, params, b_base + '1',
                              stride=block_stride, pad=1))
            o = conv2d(o, params, b_base + '2')
            if i == 0:
                # The first block projects the shortcut with ``conv_dim``
                # (same stride) so it can be added to the main branch.
                o += conv2d(x, params, b_base + '_dim', stride=stride)
            else:
                o += x
            o = F.relu(o)
        return o

    # determine network size by parameters
    blocks = []
    for j in range(4):
        # Raw string with escaped dots: match the literal key layout only.
        pattern = re.compile(r'group%d\.block\d+\.conv0\.weight' % j)
        blocks.append(sum(pattern.match(k) is not None for k in params))

    def f(input, params):
        """Run the full network on ``input`` and return the ``fc`` output."""
        o = F.conv2d(input, params['conv0.weight'], params['conv0.bias'], 2, 3)
        o = F.relu(o)
        o = F.max_pool2d(o, 3, 2, 1)
        o_g0 = group(o, params, 'group0', 1, blocks[0])
        o_g1 = group(o_g0, params, 'group1', 2, blocks[1])
        o_g2 = group(o_g1, params, 'group2', 2, blocks[2])
        o_g3 = group(o_g2, params, 'group3', 2, blocks[3])
        # Fixed 7x7 average pool -- presumably assumes 224x224 inputs so the
        # last feature map is exactly 7x7; confirm before using other sizes.
        o = F.avg_pool2d(o_g3, 7, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f
def main():
    """Evaluate the pretrained wide ResNet-50-2 weights on ImageNet validation.

    Reads images from ``<imagenetpath>/val`` (one sub-directory per class, as
    ``ImageFolder`` expects), downloads the exported parameters, runs the
    functional model over the whole set and prints top-1/top-5 accuracy.

    NOTE: ``args.model`` is parsed but not read here; the weights always come
    from the URL below.
    """
    # parse input arguments
    args = parser.parse_args()

    # set up data loader
    print("| setting up data loader...")
    valdir = os.path.join(args.imagenetpath, 'val')
    preprocess = cvtransforms.Compose([
        # ``Resize`` is the current name of the deprecated ``Scale``.
        cvtransforms.Resize(256),
        cvtransforms.CenterCrop(224),
        cvtransforms.ToTensor(),
        cvtransforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225]),
    ])
    ds = datasets.ImageFolder(valdir, preprocess)
    val_loader = torch.utils.data.DataLoader(
        ds, batch_size=256, shuffle=False,
        num_workers=args.numthreads, pin_memory=False)

    # Use the GPU when one is present, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load onto the CPU first so the download also works on CPU-only hosts,
    # then move every tensor to the chosen device.
    params = model_zoo.load_url(
        'https://s3.amazonaws.com/modelzoo-networks/wide-resnet-50-2-export-5ae25d50.pth',
        map_location='cpu')
    params = {k: v.to(device) for k, v in params.items()}
    f = define_model(params)

    class_err = tnt.meter.ClassErrorMeter(topk=[1, 5], accuracy=True)
    # Inference only: disable autograd bookkeeping for the whole loop.
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            outputs = f(inputs.to(device), params)
            class_err.add(outputs, targets)

    print('Validation top1/top5 accuracy:')
    print(class_err.value())
# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment