Last active
October 18, 2023 07:11
-
-
Save dirvuk/9e8373e4b309dfb9ed337f0d9a525fc0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from glob import glob | |
import numpy as np | |
from matplotlib.pyplot import imread | |
from sklearn.model_selection import train_test_split | |
def load_notmnist( | |
path='./notMNIST_small', | |
letters='ABCDEFGHIJ', | |
img_shape=(28,28), | |
test_size=0.25, | |
one_hot=False, | |
): | |
# download data if it's missing. If you have any problems, go to the urls and load it manually. | |
if not os.path.exists(path): | |
if not os.path.exists('./notMNIST_small.tar.gz'): | |
print("Downloading data...") | |
assert os.system('curl http://yaroslavvb.com/upload/notMNIST/notMNIST_small.tar.gz > notMNIST_small.tar.gz') == 0 | |
print("Extracting ...") | |
assert os.system('tar -zxvf notMNIST_small.tar.gz > untar_notmnist.log') == 0 | |
data,labels = [],[] | |
print("Parsing...") | |
for img_path in glob(os.path.join(path,'*/*')): | |
class_i = img_path.split(os.sep)[-2] | |
if class_i not in letters: continue | |
try: | |
data.append(imread(img_path)) | |
labels.append(class_i,) | |
except: | |
print("found broken img: %s [it's ok if <10 images are broken]" % img_path) | |
data = np.stack(data)[:,None].astype('float32') | |
data = (data - np.mean(data)) / np.std(data) | |
#convert classes to ints | |
letter_to_i = {l:i for i,l in enumerate(letters)} | |
labels = np.array(list(map(letter_to_i.get, labels))) | |
if one_hot: | |
labels = (np.arange(np.max(labels) + 1)[None,:] == labels[:, None]).astype('float32') | |
#split into train/test | |
if test_size == 0: | |
X_train, X_test, y_train, y_test = data, [], labels, [] | |
else: | |
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=test_size, random_state=42) | |
print("Done") | |
return X_train, y_train, X_test, y_test |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment