Created
June 23, 2019 19:59
-
-
Save jeremyfix/c0c60ee94e0d30753efbafa30b3cc49c to your computer and use it in GitHub Desktop.
Downloading and processing the AT&T face dataset in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# Downloading and processing the AT&T face dataset | |
import requests | |
import tempfile | |
import os | |
import glob | |
import shutil | |
import pathlib | |
import numpy | |
import sys | |
import numpy as np | |
# https://stackoverflow.com/a/35726744/2164582 | |
def read_pgm(pgmf):
    """Return the pixels of a binary (P5) PGM image as a flat list of ints.

    The header must be laid out on three lines: the magic number ``P5``,
    then ``width height``, then the maximum gray value. Comment lines in
    the header are not handled.

    Args:
        pgmf: Binary file-like object positioned at the start of the image.

    Returns:
        A flat list of ``width * height`` integers, one per pixel byte, in
        the order they appear in the file.

    Raises:
        ValueError: If the magic number is not ``P5``, the header cannot be
            parsed, the maximum gray value exceeds 255 (two-byte pixels are
            not supported), or the file ends before all pixels are read.
    """
    magic = pgmf.readline()
    if magic.strip() != b'P5':
        raise ValueError("Not a binary PGM file (magic number {!r})".format(magic))

    width, height = (int(token) for token in pgmf.readline().split())

    depth = int(pgmf.readline())
    if depth > 255:
        raise ValueError(
            "Unsupported maximum gray value {} (only 8-bit images are handled)".format(depth))

    # One bulk read instead of one read() call per pixel.
    n_pixels = width * height
    pixels = pgmf.read(n_pixels)
    if len(pixels) != n_pixels:
        raise ValueError(
            "Truncated PGM data: expected {} pixels, got {}".format(n_pixels, len(pixels)))

    # Iterating over bytes yields ints, the same values ord() gave per byte.
    return list(pixels)
# Download the AT&T face archive, extract it into a temporary directory,
# read every PGM image into a structured array and save it to
# 'att_faces.npy'.

# Imported here because this script section is the only user; the external
# tar command is run without going through a shell.
import subprocess

url = "http://www.cl.cam.ac.uk/Research/DTG/attarchive/pub/data/att_faces.tar.Z"

# Image geometry and number of images the rest of the script relies on
width = 92
height = 112
n_samples = 400

# Download the dataset
print("Downloading the dataeset")
r = requests.get(url, timeout=60)
# Fail early on an HTTP error instead of handing an error page to tar.
r.raise_for_status()

tmpdir = tempfile.mkdtemp()
try:
    # Save the downloaded data to disk, directly inside the temporary
    # directory, because the external tar command needs a file to work on.
    archive_path = os.path.join(tmpdir, 'att_faces.tar.Z')
    with open(archive_path, 'wb') as f:
        f.write(r.content)

    print("Uncompressing the dataset")
    # check=True turns a failed extraction into an exception rather than
    # silently continuing with no images.
    subprocess.run(['tar', '-zxf', archive_path, '-C', tmpdir],
                   check=True, stdout=subprocess.DEVNULL)

    print("Reading the images")
    # A plain ('label', int) field: the (type, 1) spelling is deprecated in
    # NumPy and slated to mean a shape-(1,) sub-array.
    samples = np.zeros(n_samples,
                       dtype=[('input', np.uint8, width*height), ('label', int)])

    # Sorted so that the order of the saved samples is reproducible.
    image_paths = sorted(pathlib.Path(tmpdir).glob('**/*.pgm'))
    if len(image_paths) != n_samples:
        raise RuntimeError(
            "Expected {} images in the archive, found {}".format(n_samples, len(image_paths)))

    for idx, img in enumerate(image_paths):
        sys.stdout.write('\r Processing image n°{}/{} : {}'.format(idx+1, n_samples, img))
        sys.stdout.flush()
        with open(img, 'rb') as pgmf:
            data = read_pgm(pgmf)
        # The label is the parent directory name without its first character.
        label = int(img.parent.name[1:])
        samples['input'][idx, :] = data
        samples['label'][idx] = label
    print()
finally:
    # Clean up even when the extraction or the conversion fails.
    print("Removing the temporary files")
    shutil.rmtree(tmpdir)

np.save('att_faces.npy', samples)
####################################################################
# Preview: draw ten randomly picked faces side by side, each titled with
# its label, save the figure to 'faces.png' and show it.
import matplotlib.pyplot as plt
import random

n_shown = 10
fig, axes = plt.subplots(1, n_shown, figsize=(10, 2))
for ax in axes:
    idx = random.randint(0, 399)
    # Each stored input is a flat vector; restore the 2D image for display.
    face = samples['input'][idx].reshape((height, width))
    ax.imshow(face, cmap='gray')
    # Hide the ticks on both axes.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("I'm a {}".format(samples['label'][idx]))

plt.tight_layout()
plt.savefig('faces.png', bbox_inches='tight')
plt.show()
Author
jeremyfix
commented
Jun 23, 2019
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment