@DanielTakeshi
Last active November 7, 2018 18:07
import cv2, os, pickle
import numpy as np
from os.path import join

TARGET = 'tmp/'
RAW_PICKLE_FILE = 'data_raw_115_items.pkl'


def prepare_data():
    """Create the appropriate data for PyTorch using `ImageFolder`. From:

    https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
    https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    https://pytorch.org/docs/stable/torchvision/datasets.html?highlight=imagefolder#torchvision.datasets.ImageFolder
    https://discuss.pytorch.org/t/questions-about-imagefolder/774

    If you need to call this again, delete the TARGET directory.
    """
    assert not os.path.exists(TARGET), "target directory exists:\n\t{}".format(TARGET)
    os.makedirs(TARGET)
    paths = ['train', 'valid']
    path_train = join(TARGET, paths[0])
    path_valid = join(TARGET, paths[1])

    # ImageFolder expects one sub-directory per class under each split.
    for p in paths:
        os.makedirs(join(TARGET, p))
        os.makedirs(join(TARGET, p, 'success'))
        os.makedirs(join(TARGET, p, 'failure'))
    t_success = 0
    t_failure = 0

    # Put all numbers here. For PyTorch we can use one scalar for each of the
    # mean and std, because we have one scalar here (for our depth images),
    # whose values are 'triplicated' across all three channels.
    numbers = []

    # Binary mode works on Python 2 and is required on Python 3.
    with open(RAW_PICKLE_FILE, 'rb') as fh:
        data = pickle.load(fh)

    # Pick a random 80/20 train/validation split over the item indices.
    N = len(data)
    indx_random = np.random.permutation(N)
    indx_train = set(indx_random[ : int(N*0.8)].tolist())

    # Each `item` here has a 'd_img' key, and a class label 'class' key
    # (0 means success, 1 means failure).
    for idx, item in enumerate(data):
        if idx in indx_train:
            pname = path_train
        else:
            pname = path_valid

        if item['class'] == 0:
            png_name = join(pname, 'success', 'd_{}.png'.format(str(idx).zfill(4)))
            t_success += 1
        elif item['class'] == 1:
            png_name = join(pname, 'failure', 'd_{}.png'.format(str(idx).zfill(4)))
            t_failure += 1
        else:
            raise ValueError(item['class'])
        cv2.imwrite(png_name, item['d_img'])

        # Accumulate statistics for mean and std computation across our
        # lone channel. We made values same across all three channels.
        d_img = item['d_img']
        assert d_img.shape == (480,640,3)
        assert np.sum(d_img[:,:,0]) == np.sum(d_img[:,:,1]) == np.sum(d_img[:,:,2])
        numbers.append( d_img[:,:,0].flatten() )

    print("done loading data, success {} vs failure {} (total {})".format(
            t_success, t_failure, N))
    numbers = np.concatenate(numbers)
    print("len(numbers): {} (has the single-channel mean/std info)".format(len(numbers)))
    print("mean(numbers): {}".format(np.mean(numbers)))
    print("std(numbers): {}".format(np.std(numbers)))

    # Divide by 255 so the statistics are on the [0,1] scale, which is the
    # scale torchvision's ToTensor produces from uint8 images.
    print("\nBut, use this for actual mean/std because we want them in [0,256) ...")
    print("mean(scaled): {}".format(np.mean(numbers/255.0)))
    print("std(scaled): {}".format(np.std(numbers/255.0)))


if __name__ == "__main__":
    prepare_data()
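
The script keeps every pixel value of channel 0 in memory before computing the mean and std. As an alternative, here is a minimal sketch that gets the same two numbers from running sums, one image at a time. The function name `running_mean_std` and the small random arrays are hypothetical stand-ins for illustration; they are not part of the original script or data.

```python
import numpy as np


def running_mean_std(images):
    """Mean and population std over channel 0 of all images, via running sums."""
    count = 0
    total = 0.0
    total_sq = 0.0
    for img in images:
        # Cast to float64 so the squares of uint8 values do not overflow.
        chan = img[:, :, 0].astype(np.float64)
        count += chan.size
        total += chan.sum()
        total_sq += np.square(chan).sum()
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clamp at 0 to guard against rounding error.
    var = max(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


# Hypothetical tiny stand-in for the real (480, 640, 3) depth images.
rng = np.random.RandomState(0)
fake_images = [rng.randint(0, 256, size=(4, 5, 3)).astype(np.uint8)
               for _ in range(3)]
mean, std = running_mean_std(fake_images)
print("mean: {}, std: {}".format(mean, std))
```

On the real data, one would pass `(item['d_img'] for item in data)` as `images`. The result should agree with `np.mean` and `np.std` on the concatenated pixels up to floating-point rounding.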
DanielTakeshi commented Nov 7, 2018

The data is here in a standard pickle file: https://drive.google.com/open?id=1UoBjkmqMkijQ95eUOfmtuSGKBHQjS6n9 (624 MB)

Put it in the same directory as this script, `build_data.py`.

Run `python build_data.py` and the output is:

done loading data, success 64 vs failure 51 (total 115)
len(numbers):  35328000  (has the single-channel mean/std info)
mean(numbers): 93.8304761096
std(numbers):  84.9985507432

But, use this for actual mean/std because we want them in [0,256) ...
mean(scaled): 0.36796265141
std(scaled):  0.333327649973
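
As a quick consistency check on the numbers above, the scaled statistics are the raw ones divided by 255, which puts them on the [0,1] scale. The snippet below is only an arithmetic sketch using the printed values; the helper name `matches` is hypothetical.

```python
# Values copied from the printed output above.
raw_mean, raw_std = 93.8304761096, 84.9985507432
scaled_mean, scaled_std = 0.36796265141, 0.333327649973


def matches(raw, scaled, divisor, tol=1e-9):
    """True if raw / divisor reproduces the scaled value within tol."""
    return abs(raw / divisor - scaled) < tol


# Compare the two candidate divisors against the printed scaled values.
for divisor in (255.0, 256.0):
    print(divisor,
          matches(raw_mean, scaled_mean, divisor),
          matches(raw_std, scaled_std, divisor))
```

Since torchvision's `ToTensor` also divides uint8 images by 255, these scaled values are on the scale one would typically pass to `transforms.Normalize`, repeated across the three identical channels.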