Skip to content

Instantly share code, notes, and snippets.

@davidbau
Last active February 21, 2018 15:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save davidbau/4ef0c2eca8430cd167fb045a8a474663 to your computer and use it in GitHub Desktop.
Save davidbau/4ef0c2eca8430cd167fb045a8a474663 to your computer and use it in GitHub Desktop.
Script for downloading and formatting miniplaces in pytorch ImageFolder format
#!/usr/bin/env python2.7
# Script to create a simple, flat PyTorch ImageFolder hierarchy
# of training and validation images for miniplaces. Each category
# name is just a folder name (numbered in alphabetical order as in
# the original miniplaces), and both train and val images are placed
# directly inside a single level of folders with the flat category names.
import shutil, os, tarfile
def ensure_dir(dirname):
    """Create directory *dirname*, including missing parents, if needed.

    Does nothing when *dirname* already exists as a directory. Any other
    failure (e.g. permission denied, or a regular file already occupying
    the path) is re-raised instead of being silently ignored.
    """
    try:
        os.makedirs(dirname)
    except OSError:
        # Python 2.7's makedirs has no exist_ok=True, so tolerate the
        # "already exists" case by checking the outcome, and propagate
        # everything else.
        if not os.path.isdir(dirname):
            raise
# Download and untar data files.
# The small index files come from the miniplaces GitHub repository; the
# image archive comes from the CSAIL server.
gitdir = 'https://raw.githubusercontent.com/CSAILVision/miniplaces/master/data'
index_files = ('val.txt', 'train.txt', 'categories.txt', 'object_categories.txt')
urls = ['%s/%s' % (gitdir, index_file) for index_file in index_files]
urls.append('http://miniplaces.csail.mit.edu/data/data.tar.gz')
# 'raw' holds the downloaded files as-is; 'data' holds the unpacked copies.
for staging_dir in ('raw', 'data'):
    ensure_dir(staging_dir)
# python2 vs 3: urlopen lives in urllib.request on Python 3 and in urllib2
# on Python 2. Import the submodule explicitly, because a bare
# "import urllib" does not load urllib.request. Catch only ImportError so
# unrelated errors are not masked by the fallback.
try:
    from urllib.request import urlopen
except ImportError:
    from urllib2 import urlopen
# Fetch each URL into raw/ (skipping files already present), then unpack
# tarballs into data/ and copy every other file there unchanged.
for url in urls:
    filename = url.rpartition('/')[2]
    file_path = os.path.join('raw', filename)
    if not os.path.exists(file_path):
        print('Downloading %s' % url)
        # Download into a ".part" file and rename only after the transfer
        # completes. Otherwise an interrupted download would leave a
        # truncated file that later runs skip as if it were complete.
        partial_path = file_path + '.part'
        response = urlopen(url)
        try:
            with open(partial_path, 'wb') as f:
                # Stream in chunks instead of holding the whole tarball in memory.
                shutil.copyfileobj(response, f)
        finally:
            response.close()
        os.rename(partial_path, file_path)
    if file_path.endswith('.tar.gz'):
        with tarfile.open(file_path) as tar:
            print('Untarring %s' % file_path)
            # NOTE(review): extractall trusts the member paths stored in the
            # archive; only use this with the trusted miniplaces tarball.
            tar.extractall('data')
    else:
        shutil.copyfile(file_path,
                        os.path.join('data', os.path.basename(file_path)))
# Now copy them into simple pytorch data format.
ensure_dir('simple/train')
ensure_dir('simple/val')
# Copy the train images to flat category directory names.
# The first path component below train_root (presumably the one-letter
# bucket, e.g. 'a') is dropped, and the remaining components are joined
# with '-', so a nested <category>/<sub> folder becomes <category>-<sub>.
train_root = 'data/images/train'
categories = []
trainfiles = []
for root, _, files in os.walk(train_root):
    jpgs = [name for name in files if name.endswith('.jpg')]
    if not jpgs:
        continue
    # Work from the path relative to train_root and split on os.sep.
    # os.walk joins subdirectories with the platform separator, so splitting
    # the full path on '/' yields an empty category name on Windows.
    catname = '-'.join(os.path.relpath(root, train_root).split(os.sep)[1:])
    categories.append(catname)
    ensure_dir('simple/train/%s' % catname)
    # Also create the matching val folder so train and val share one layout.
    ensure_dir('simple/val/%s' % catname)
    print('Copying train/%s' % catname)
    for name in jpgs:
        target = 'train/%s/%s' % (catname, name)
        trainfiles.append(target)
        shutil.copyfile(os.path.join(root, name), os.path.join('simple', target))
# Val labels are integer indices into this list, so it must be in the
# alphabetical order used by the original miniplaces numbering.
categories.sort()
# Save a file listing all images, which can be used to speed loading.
trainfiles.sort()
with open('simple/train.txt', 'w') as f:
    for filename in trainfiles:
        f.write('%s\n' % filename)
# Copy the val images to the same flat category directory names.
# Each line of data/val.txt holds "<image path relative to data/images>
# <category index>", where the index selects an entry of the sorted
# `categories` list built from the train folders.
valfiles = []
with open('data/val.txt') as f:
    for line in f:
        fields = line.split()
        if not fields:
            # Skip blank lines (e.g. extra empty lines at the end of the
            # file) instead of failing to unpack them.
            continue
        fn, catnum = fields
        basename = os.path.basename(fn)
        catname = categories[int(catnum)]
        target = 'val/%s/%s' % (catname, basename)
        print('Copying %s' % target)
        valfiles.append(target)
        shutil.copyfile(os.path.join('data/images', fn),
                        os.path.join('simple', target))
# Save a file listing all images, which can be used to speed loading.
valfiles.sort()
with open('simple/val.txt', 'w') as f:
    for filename in valfiles:
        f.write('%s\n' % filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment