Skip to content

Instantly share code, notes, and snippets.

@Splines
Created October 3, 2021 01:23
Show Gist options
  • Save Splines/90eb505fc5e02821fe30575ac9d38587 to your computer and use it in GitHub Desktop.
Save Splines/90eb505fc5e02821fe30575ac9d38587 to your computer and use it in GitHub Desktop.
Parse the MNIST train set
# 🌐 MNIST Dataset
# https://deepai.org/dataset/mnist
import numpy as np
from matplotlib import pyplot as plt
def _read_exact(file, size):
    """Read exactly ``size`` bytes from ``file`` or raise ``ValueError``.

    ``file.read`` silently returns fewer bytes at end-of-file; without this
    check a truncated IDX file would only surface later as a confusing
    ``reshape`` error.
    """
    data = file.read(size)
    if len(data) != size:
        raise ValueError(
            f'{getattr(file, "name", "file")}: expected {size} bytes, '
            f'got {len(data)} (truncated file?)'
        )
    return data


# IDX magic numbers: two zero bytes, a data-type byte (0x08 = unsigned byte)
# and the number of dimensions (3 for the image file, 1 for the label file).
IMAGES_MAGIC_NUMBER = 0x00000803  # 2051
LABELS_MAGIC_NUMBER = 0x00000801  # 2049

with open('./mnist/train-images.idx3-ubyte', 'rb') as images_file, \
        open('./mnist/train-labels.idx1-ubyte', 'rb') as labels_file:
    # --- Images Header
    # 4x 32-bit big-endian integer: magic number, #images, #rows, #cols.
    # Read via file.read + frombuffer so all reads go through the same
    # buffered file object (no mixing of np.fromfile and file.read).
    images_header = np.frombuffer(_read_exact(images_file, 16), dtype='>i4')
    # Convert to Python ints so size arithmetic below cannot overflow int32.
    images_magic_number, images_count, row_count, col_count = (
        int(value) for value in images_header
    )
    pixel_count = row_count * col_count
    print(f'magic number\t {images_magic_number}')
    print(f'#images\t\t {images_count}')
    print(f'#rows\t\t {row_count}')
    print(f'#cols\t\t {col_count}')
    print(f'#pixels\t\t {pixel_count} ({row_count}*{col_count})')
    if images_magic_number != IMAGES_MAGIC_NUMBER:
        raise ValueError(
            f'unexpected images magic number {images_magic_number}, '
            f'expected {IMAGES_MAGIC_NUMBER}'
        )

    # --- Images Data
    # One unsigned byte per pixel; one row per image, flattened.
    images_buffer = _read_exact(images_file, pixel_count * images_count)
    images = np.frombuffer(images_buffer, dtype=np.uint8)
    images = images.reshape(images_count, pixel_count)

    # --- Labels Header
    # 2x 32-bit big-endian integer: magic number, #labels.
    labels_header = np.frombuffer(_read_exact(labels_file, 8), dtype='>i4')
    labels_magic_number, label_count = (int(value) for value in labels_header)
    print(f'magic number\t {labels_magic_number}')
    print(f"#labels\t\t {label_count}")
    if labels_magic_number != LABELS_MAGIC_NUMBER:
        raise ValueError(
            f'unexpected labels magic number {labels_magic_number}, '
            f'expected {LABELS_MAGIC_NUMBER}'
        )
    # Images and labels are paired by index, so the counts must agree.
    if label_count != images_count:
        raise ValueError(
            f'label count {label_count} does not match '
            f'image count {images_count}'
        )

    # --- Labels Data
    # One unsigned byte per label.
    labels_buffer = _read_exact(labels_file, label_count)
    labels = np.frombuffer(labels_buffer, dtype=np.uint8)

# --- Plot image
# Use the dimensions parsed from the header instead of assuming 28x28.
img = images[0].reshape((row_count, col_count))
plt.imshow(img, cmap='gray')
plt.title(f'label: {labels[0]}')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment