Skip to content

Instantly share code, notes, and snippets.

@jimmyahacker
Created July 8, 2018 18:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jimmyahacker/393bc74b67e24b2eef663ca28244e46b to your computer and use it in GitHub Desktop.
Save jimmyahacker/393bc74b67e24b2eef663ca28244e46b to your computer and use it in GitHub Desktop.
How to extract data from an IDX file
import numpy as np
def _read(dimensions, stream):
    """
    Recursively read unsigned bytes from ``stream`` as nested lists.

    Parameters
    ----------
    dimensions: sequence of int
        Sizes of the nested levels to read, outermost first. An empty
        sequence reads a single byte.
    stream: binary file-like object
        Source positioned at the first byte to read; it is advanced by
        the product of ``dimensions`` bytes.

    Returns
    -------
    int or list
        A single byte value when ``dimensions`` is empty, otherwise a list
        nested ``len(dimensions)`` levels deep whose leaves are ints.

    Raises
    ------
    EOFError
        If the stream ends before all requested bytes are read (previously
        this surfaced as a TypeError or as a silently short row).
    """
    if not dimensions:
        byte = stream.read(1)
        if len(byte) != 1:
            raise EOFError('unexpected end of IDX data')
        return ord(byte)
    if len(dimensions) == 1:
        width = dimensions[0]
        row = stream.read(width)
        # A short read means the file is truncated; fail instead of
        # returning a ragged row.
        if len(row) != width:
            raise EOFError('expected {} bytes, got {}'.format(width, len(row)))
        # Iterating a bytes object yields ints.
        return list(row)
    inner = dimensions[1:]
    return [_read(inner, stream) for _ in range(dimensions[0])]
def _read_exact(stream, count):
    """
    Read exactly ``count`` bytes from a binary ``stream``.

    Raises
    ------
    EOFError
        If the stream ends before ``count`` bytes are available.
    """
    buf = stream.read(count)
    if len(buf) != count:
        raise EOFError('expected {} bytes, got {}'.format(count, len(buf)))
    return buf


def extract(idx_filename, *, verbose=False):
    """
    Extract information (images/labels) from an IDX file.

    Only the unsigned-byte (type code 0x08) variant of the format is
    supported. The whole payload is read in a single call and decoded with
    NumPy rather than byte-by-byte.

    Parameters
    ----------
    idx_filename: str
        Path of the IDX file to read.
    verbose: bool, keyword-only, default False
        When True, print the parsed header fields and the per-example shape.

    Returns
    -------
    numpy.ndarray
        Writable array of dtype ``numpy.int_`` with shape
        ``(num_examples, *dimensions)`` as designated in the IDX header.

    Raises
    ------
    AssertionError
        If the first two magic bytes are not zero, the data type is not
        unsigned byte, or the header declares zero dimensions.
    EOFError
        If the file ends before the header or the declared payload is read.
    """
    with open(idx_filename, 'rb') as f:
        magic_numbers = _read_exact(f, 4)
        if verbose:
            print('magic_number', magic_numbers)
        # Explicit raise instead of ``assert`` so validation survives
        # ``python -O``; AssertionError is kept for existing callers.
        if magic_numbers[0] != 0 or magic_numbers[1] != 0:
            raise AssertionError(
                'Invalid IDX magic number: {!r}'.format(magic_numbers))
        if magic_numbers[2] != 8:
            raise AssertionError('Only support for unsigned char now')
        # The 4th magic byte is the number of dimensions, including the
        # leading "number of examples" dimension.
        ndim = magic_numbers[3]
        if ndim == 0:
            raise AssertionError('IDX header declares zero dimensions')
        if verbose:
            print('shape', ndim)
        num_examples = int.from_bytes(_read_exact(f, 4), byteorder='big')
        if verbose:
            print('number of examples', num_examples)
        dimensions = [int.from_bytes(_read_exact(f, 4), byteorder='big')
                      for _ in range(ndim - 1)]
        if verbose:
            print('dimensions', dimensions)
        count = num_examples
        for size in dimensions:
            count *= size
        payload = _read_exact(f, count)
    data = np.frombuffer(payload, dtype=np.uint8).reshape(
        [num_examples] + dimensions)
    if verbose:
        print('each data point', data.shape[1:])
    # astype() gives a writable copy (frombuffer views are read-only) using
    # the default integer dtype that np.array() of Python ints produced.
    return data.astype(np.int_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment