@soumith
Created October 26, 2018 15:45
diff --git a/imagenet/main.py b/imagenet/main.py
index 20838f0..783bbf2 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -20,8 +20,6 @@ model_names = sorted(name for name in models.__dict__
     and callable(models.__dict__[name]))

 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-parser.add_argument('data', metavar='DIR',
-                    help='path to dataset')
 parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                     choices=model_names,
                     help='model architecture: ' +
@@ -111,14 +109,16 @@ def main():
     cudnn.benchmark = True

     # Data loading code
-    traindir = os.path.join(args.data, 'train')
-    valdir = os.path.join(args.data, 'val')
+    # traindir = os.path.join(args.data, 'train')
+    # valdir = os.path.join(args.data, 'val')
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

-    train_dataset = datasets.ImageFolder(
-        traindir,
-        transforms.Compose([
+    train_dataset = datasets.FakeData(
+        size=1200000, # imagenet training size
+        num_classes=1000, # imagenet number of classes
+        image_size=(3, 224, 224),
+        transform=transforms.Compose([
             transforms.RandomResizedCrop(224),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
@@ -135,7 +135,11 @@ def main():
         num_workers=args.workers, pin_memory=True, sampler=train_sampler)

     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
+        datasets.FakeData(
+            size=50000, # imagenet validation size
+            num_classes=1000, # imagenet number of classes
+            image_size=(3, 224, 224),
+            transform=transforms.Compose([
             transforms.Resize(256),
             transforms.CenterCrop(224),
             transforms.ToTensor(),
@@ -186,7 +190,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         # measure data loading time
         data_time.update(time.time() - end)

-        target = target.cuda(non_blocking=True)
+        target = target.cuda(non_blocking=True).to(dtype=torch.int64)

         # compute output
         output = model(input)
@@ -230,7 +234,7 @@ def validate(val_loader, model, criterion):
     with torch.no_grad():
         end = time.time()
         for i, (input, target) in enumerate(val_loader):
-            target = target.cuda(non_blocking=True)
+            target = target.cuda(non_blocking=True).to(dtype=torch.int64)

             # compute output
             output = model(input)
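
The patch above swaps the example's ImageFolder datasets for torchvision.datasets.FakeData, which synthesizes random images on the fly, so imagenet/main.py can be exercised without a real ImageNet directory; the added .to(dtype=torch.int64) cast covers the fact that FakeData targets are not necessarily the int64 class indices that nn.CrossEntropyLoss expects. The snippet below is a minimal standalone sketch of the same idea; the small dataset size, batch size, worker count, and single smoke-test iteration are illustrative assumptions, not part of the gist.

# Minimal sketch: build a FakeData loader shaped like the ImageNet pipeline.
# Assumptions (not from the gist): dataset size 1024, batch_size=32, num_workers=2.
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# FakeData yields randomly generated PIL images, so the usual ImageNet
# transforms can be applied unchanged and no data directory is needed.
train_dataset = datasets.FakeData(
    size=1024,                  # tiny stand-in instead of 1200000
    num_classes=1000,
    image_size=(3, 224, 224),
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=2)

for input, target in train_loader:
    # Cast the targets as the patch does; FakeData may return them as a
    # float tensor, while CrossEntropyLoss wants int64 class indices.
    target = target.to(dtype=torch.int64)
    print(input.shape, target.shape, target.dtype)
    break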