Function to extract data from the CIFAR-10 dataset
import pickle
import numpy as np
from os import listdir
from os.path import isfile, join
import os

# Function to unpickle the CIFAR-10 dataset
def unpickle_all_data(directory):
    # Initialize the variables
    train = dict()
    test = dict()
    train_x = []
    train_y = []
    test_x = []
    test_y = []

    # Iterate through all the files we want, train and test
    # The train data is split into batches
    for filename in listdir(directory):
        if isfile(join(directory, filename)):
            # The train data
            if 'data_batch' in filename:
                print('Handling file: %s' % filename)
                # Open the file
                with open(join(directory, filename), 'rb') as fo:
                    data = pickle.load(fo, encoding='bytes')
                if 'data' not in train:
                    train['data'] = data[b'data']
                    train['labels'] = np.array(data[b'labels'])
                else:
                    train['data'] = np.concatenate((train['data'], data[b'data']))
                    train['labels'] = np.concatenate((train['labels'], data[b'labels']))
            # The test data
            elif 'test_batch' in filename:
                print('Handling file: %s' % filename)
                # Open the file
                with open(join(directory, filename), 'rb') as fo:
                    data = pickle.load(fo, encoding='bytes')
                test['data'] = data[b'data']
                test['labels'] = data[b'labels']

    # Reshape each flat 3072-value row into a 32x32x3 (height, width, channel) image
    for image in train['data']:
        train_x.append(np.transpose(np.reshape(image, (3, 32, 32)), (1, 2, 0)))
    train_y = [label for label in train['labels']]

    for image in test['data']:
        test_x.append(np.transpose(np.reshape(image, (3, 32, 32)), (1, 2, 0)))
    test_y = [label for label in test['labels']]

    # Transform the data to numpy array format
    train_x = np.array(train_x)
    train_y = np.array(train_y)
    test_x = np.array(test_x)
    test_y = np.array(test_y)

    return (train_x, train_y), (test_x, test_y)

# Run the function, passing the folder that contains the CIFAR-10 batches
(x_train, y_train), (x_test, y_test) = unpickle_all_data(os.getcwd() + '/cifar-10-batches-py/')
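
# Assumption (not in the original gist): the output folders used below must
# already exist, since pickle.dump will not create them. A minimal way to
# make sure they are there:
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/validation', exist_ok=True)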

# Pickle the processed arrays so other scripts can load them directly
with open('data/validation/test-x', 'wb') as handle:
    pickle.dump(x_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/validation/test-y', 'wb') as handle:
    pickle.dump(y_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/train/train-x', 'wb') as handle:
    pickle.dump(x_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/train/train-y', 'wb') as handle:
    pickle.dump(y_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
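
# Optional sanity check, a minimal sketch and not part of the original gist:
# reload the pickled arrays and confirm their shapes. With the full CIFAR-10
# dataset they should be (50000, 32, 32, 3) for train and (10000, 32, 32, 3)
# for test.
with open('data/train/train-x', 'rb') as handle:
    x_train_loaded = pickle.load(handle)
with open('data/validation/test-x', 'rb') as handle:
    x_test_loaded = pickle.load(handle)
print('Train images:', x_train_loaded.shape)
print('Test images:', x_test_loaded.shape)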