# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import hashlib
import json
import sys
import tarfile
from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle
# Configure the matplotlib backend to plot inline in IPython.
%matplotlib inline
url = 'http://commondatastorage.googleapis.com/books1000/'
def maybe_download(filename, expected_bytes, force=False):
  """Download a file if not present, and make sure it's the right size."""
  if force or not os.path.exists(filename):
    filename, _ = urlretrieve(url + filename, filename)
  statinfo = os.stat(filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified', filename)
  else:
    raise Exception(
      'Failed to verify ' + filename + '. Can you get to it with a browser?')
  return filename
train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)
num_classes = 10
np.random.seed(133)
def maybe_extract(filename, force=False):
  root = os.path.splitext(os.path.splitext(filename)[0])[0]  # remove .tar.gz
  if os.path.isdir(root) and not force:
    # You may override by setting force=True.
    print('%s already present - Skipping extraction of %s.' % (root, filename))
  else:
    print('Extracting data for %s. This may take a while. Please wait.' % root)
    tar = tarfile.open(filename)
    sys.stdout.flush()
    tar.extractall()
    tar.close()
  data_folders = [
    os.path.join(root, d) for d in sorted(os.listdir(root))
    if os.path.isdir(os.path.join(root, d))]
  if len(data_folders) != num_classes:
    raise Exception(
      'Expected %d folders, one per class. Found %d instead.' % (
        num_classes, len(data_folders)))
  print(data_folders)
  return data_folders
train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)
image_size = 28  # Pixel width and height.
pixel_depth = 255.0  # Number of levels per pixel.

def load_letter(folder, min_num_images):
  """Load the data for a single letter label."""
  image_files = os.listdir(folder)
  dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                       dtype=np.float32)
  print(folder)
  num_images = 0
  for image in image_files:
    image_file = os.path.join(folder, image)
    try:
      image_data = (ndimage.imread(image_file).astype(float) -
                    pixel_depth / 2) / pixel_depth
      if image_data.shape != (image_size, image_size):
        raise Exception('Unexpected image shape: %s' % str(image_data.shape))
      # Only advance the write index on a successful read, so skipped files
      # don't leave uninitialized slots in the tensor.
      dataset[num_images, :, :] = image_data
      num_images = num_images + 1
    except IOError as e:
      print('Could not read:', image_file, ':', e, "- it's ok, skipping.")
  dataset = dataset[0:num_images, :, :]
  if num_images < min_num_images:
    raise Exception('Many fewer images than expected: %d < %d' %
                    (num_images, min_num_images))
  print('Full dataset tensor:', dataset.shape)
  print('Mean:', np.mean(dataset))
  print('Standard deviation:', np.std(dataset))
  return dataset

def maybe_pickle(data_folders, min_num_images_per_class, force=False):
  dataset_names = []
  for folder in data_folders:
    set_filename = folder + '.pickle'
    dataset_names.append(set_filename)
    if os.path.exists(set_filename) and not force:
      # You may override by setting force=True.
      print('%s already present - Skipping pickling.' % set_filename)
    else:
      print('Pickling %s.' % set_filename)
      dataset = load_letter(folder, min_num_images_per_class)
      try:
        with open(set_filename, 'wb') as f:
          pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
      except Exception as e:
        print('Unable to save data to', set_filename, ':', e)
  return dataset_names

# Minimum images expected per class in each archive.
train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)

# Sanity check: display a random sample from the first pickled letter class.
pickle_file = train_datasets[0]  # index 0 should be all As, 1 = all Bs, etc.
with open(pickle_file, 'rb') as f:
  letter_set = pickle.load(f)  # unpickle
sample_idx = np.random.randint(len(letter_set))  # pick a random image index
sample_image = letter_set[sample_idx, :, :]  # extract a 2D slice
plt.figure()
plt.imshow(sample_image)  # display it
def make_arrays(nb_rows, img_size):
  if nb_rows:
    dataset = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)
    labels = np.ndarray(nb_rows, dtype=np.int32)
  else:
    dataset, labels = None, None
  return dataset, labels
def merge_datasets(pickle_files, train_size, valid_size=0):
  num_classes = len(pickle_files)
  valid_dataset, valid_labels = make_arrays(valid_size, image_size)
  train_dataset, train_labels = make_arrays(train_size, image_size)
  vsize_per_class = valid_size // num_classes
  tsize_per_class = train_size // num_classes

  start_v, start_t = 0, 0
  end_v, end_t = vsize_per_class, tsize_per_class
  end_l = vsize_per_class + tsize_per_class
  for label, pickle_file in enumerate(pickle_files):
    try:
      with open(pickle_file, 'rb') as f:
        letter_set = pickle.load(f)
        # let's shuffle the letters to have random validation and training set
        np.random.shuffle(letter_set)
        if valid_dataset is not None:
          valid_letter = letter_set[:vsize_per_class, :, :]
          valid_dataset[start_v:end_v, :, :] = valid_letter
          valid_labels[start_v:end_v] = label
          start_v += vsize_per_class
          end_v += vsize_per_class

        train_letter = letter_set[vsize_per_class:end_l, :, :]
        train_dataset[start_t:end_t, :, :] = train_letter
        train_labels[start_t:end_t] = label
        start_t += tsize_per_class
        end_t += tsize_per_class
    except Exception as e:
      print('Unable to process data from', pickle_file, ':', e)
      raise

  return valid_dataset, valid_labels, train_dataset, train_labels
train_size = 200000
valid_size = 10000
test_size = 10000
valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(
  train_datasets, train_size, valid_size)
_, _, test_dataset, test_labels = merge_datasets(test_datasets, test_size)
print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
print('Testing:', test_dataset.shape, test_labels.shape)
def randomize(dataset, labels):
  permutation = np.random.permutation(labels.shape[0])
  shuffled_dataset = dataset[permutation, :, :]
  shuffled_labels = labels[permutation]
  return shuffled_dataset, shuffled_labels
train_dataset, train_labels = randomize(train_dataset, train_labels)
test_dataset, test_labels = randomize(test_dataset, test_labels)
valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)
pickle_file = 'notMNIST.pickle'
try:
  f = open(pickle_file, 'wb')
  save = {
    'train_dataset': train_dataset,
    'train_labels': train_labels,
    'valid_dataset': valid_dataset,
    'valid_labels': valid_labels,
    'test_dataset': test_dataset,
    'test_labels': test_labels,
    }
  pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
  f.close()
except Exception as e:
  print('Unable to save data to', pickle_file, ':', e)
  raise
statinfo = os.stat(pickle_file)
print('Compressed pickle size:', statinfo.st_size)
with open('notMNIST.pickle', 'rb') as f:
  all_data = pickle.load(f)
def count_duplicates(dataset1, dataset2):
  hashes = [hashlib.sha1(x).hexdigest() for x in dataset1]
  dup_indices = []
  for i in range(len(dataset2)):
    if hashlib.sha1(dataset2[i]).hexdigest() in hashes:
      dup_indices.append(i)
  return len(dup_indices)
print(count_duplicates(all_data['test_dataset'], all_data['valid_dataset']))
print(count_duplicates(all_data['valid_dataset'], all_data['train_dataset']))
print(count_duplicates(all_data['test_dataset'], all_data['train_dataset']))
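# A minimal sketch (not part of the original gist) of how the overlaps counted
# above could be dropped rather than just counted. It assumes only exact,
# byte-identical duplicates matter; the helper name remove_duplicates is
# hypothetical.
def remove_duplicates(dataset, labels, reference_dataset):
  reference_hashes = set(hashlib.sha1(x).hexdigest() for x in reference_dataset)
  keep = [i for i in range(len(dataset))
          if hashlib.sha1(dataset[i]).hexdigest() not in reference_hashes]
  return dataset[keep], labels[keep]

# Hypothetical usage: drop validation images that also appear in the training set.
# valid_clean, valid_clean_labels = remove_duplicates(
#   all_data['valid_dataset'], all_data['valid_labels'], all_data['train_dataset'])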
train_dataset = all_data['train_dataset']
train_labels = all_data['train_labels']
test_dataset = all_data['test_dataset']
test_labels = all_data['test_labels']
def get_score(train_dataset, train_labels, test_dataset, test_labels):
  model = LogisticRegression()
  # Flatten each 28x28 image into a 784-dimensional feature vector.
  train_flatten_dataset = np.array([x.flatten() for x in train_dataset])
  test_flatten_dataset = np.array([x.flatten() for x in test_dataset])
  model.fit(train_flatten_dataset, train_labels)
  return model.score(test_flatten_dataset, test_labels)
print("100 trainsamples score: " + str(get_score(train_dataset[:100], train_labels[:100], test_dataset, test_labels)))
print("1000 trainsamples score: " + str(get_score(train_dataset[:1000], train_labels[:1000], test_dataset, test_labels)))
print("5000 trainsamples score: " + str(get_score(train_dataset[:5000], train_labels[:5000], test_dataset, test_labels)))
print("10000 trainsamples score: " + str(get_score(train_dataset[:10000], train_labels[:10000], test_dataset, test_labels)))
print("all trainsamples score: " + str(get_score(train_dataset, train_labels, test_dataset, test_labels)))