Skip to content

Instantly share code, notes, and snippets.

@perimosocordiae
Created March 23, 2010 06:55
Show Gist options
  • Save perimosocordiae/340904 to your computer and use it in GitHub Desktop.
Save perimosocordiae/340904 to your computer and use it in GitHub Desktop.
PCA-based face recognition
#!/usr/bin/env python
from sys import argv, exit
from glob import glob
from matplotlib.pyplot import show, imshow, subplot, gray
from matplotlib.image import imread
from numpy import zeros,dot,sum,arange,savez,load,flipud,array
from scipy.linalg import eig
# Dataset switches. ATT selects the ORL faces ("orl_faces"); otherwise LFW
# is used, cropped to its central 150x150 window when CLIP is set.
ATT = False
CLIP = False
if ATT or not CLIP:
    # Images are read uncropped: 92x112 (W x H) for ORL, 250x250 for LFW.
    W, H = (92, 112) if ATT else (250, 250)

    def read(fname):
        """Load the image at fname and return its pixels as a flat array."""
        return imread(fname).flatten()
else:
    # LFW clipped to rows/columns 50..199, i.e. a 150x150 window.
    W = H = 150

    def read(fname):
        """Load the image at fname, keep [50:200, 50:200], and flatten it."""
        return imread(fname)[50:200, 50:200].flatten()
def disp(raster):
    """Display a flattened image with matplotlib.

    raster is reshaped to (H, W) using the module-level image size and
    flipped vertically before being passed to imshow(); returns None.
    """
    # NOTE(review): the vertical flip presumably compensates for the row
    # order imread returned in the matplotlib version this was written
    # for -- confirm before relying on it.
    image = raster.reshape((H, W))
    flipped = flipud(image)
    imshow(flipped)
def project(img, mean, basis, clo, chi):
    """Project a mean-centred image onto a slice of basis columns.

    Returns dot(img - mean, basis[:, clo:chi]): for a 1-D img this is the
    vector of coefficients for basis columns clo .. chi-1.
    """
    centred = img - mean
    components = basis[:, clo:chi]
    return dot(centred, components)
def compute_vectors(face_paths, num_trains):
    """Build an eigenface basis from the training images in face_paths.

    For every subject ``subj`` the paths ``subj[2:]`` are used for training,
    skipping falsy placeholders (the None padding added in ``__main__``);
    ``subj[0]`` and ``subj[1]`` are held out as the target/query pair.

    Args:
        face_paths: per-subject sequences of image paths.
        num_trains: number of training images; must equal the number of
            non-falsy entries in all ``subj[2:]`` slices.

    Returns:
        (basis, mean_img): ``basis`` has one column per training image,
        ordered by decreasing eigenvalue. Column k equals train^T v_k, so its
        norm is sqrt(lambda_k); the columns are NOT normalized. ``mean_img``
        is the pixel-wise mean of the training images.

    Raises:
        ValueError: if fewer than num_trains images were read (IndexError
            is raised if there are more).
    """
    from scipy.linalg import eigh

    # NOTE(review): assumes read() returns exactly W*H values per image.
    train = zeros((num_trains, W * H))
    i = 0
    for subj in face_paths:
        for fname in (s for s in subj[2:] if s):
            train[i] = read(fname)
            i += 1
    if i != num_trains:
        # Unfilled zero rows would silently skew the mean and the basis.
        raise ValueError("expected %d training images, read %d" % (num_trains, i))
    print("Read training set")
    mean_img = train.mean(0)
    train -= mean_img  # centre every row on the mean image
    print("Subtracted mean")
    tt = train.transpose()
    # Gram-matrix trick: train.train^T is only (n x n) and symmetric. If v is
    # one of its eigenvectors, then train^T.v is an eigenvector of
    # train^T.train (the pixel scatter matrix) with the same eigenvalue.
    # eigh returns real eigenvalues in ascending order and the eigenvectors
    # as COLUMNS, so reverse the columns for decreasing order.
    vals, vecs = eigh(dot(train, tt))
    return dot(tt, vecs[:, ::-1]), mean_img
def make_sim_matrix(target_set, query_set):
    """Return the matrix of squared Euclidean distances between two sets.

    Entry [t, q] is sum((target_set[t] - query_set[q])**2); smaller means
    more similar, so despite the name this is a distance matrix. The result
    has shape (len(target_set), len(query_set)).
    """
    dists = zeros((len(target_set), len(query_set)))
    for row, target in enumerate(target_set):
        dists[row, :] = [sum((target - query) ** 2) for query in query_set]
    return dists
def score(sim):
    """Return the fraction of queries whose nearest target is their own.

    sim is a distance matrix (targets on rows, queries on columns, both in
    subject order). Column j counts as correct when its smallest entry lies
    in row j. Assumes a square matrix (one target and one query per
    subject).
    """
    n_rows = sim.shape[0]
    best_rows = sim.argmin(0)
    wrong = int(((best_rows - arange(n_rows)) != 0).sum())
    return (n_rows - wrong) / float(n_rows)
def mean_mtx(face_paths):
    """Return an array with one mean image per subject.

    Row s is the average of read(p) over p in face_paths[s][1:], i.e. every
    image of subject s except the first. The stacked array is sized from
    face_paths[0], so this assumes non-jagged face_paths (won't work for
    lfw); mean_mtx_jagged handles padded path lists.
    """
    n_images = len(face_paths[0]) - 1
    stack = zeros((len(face_paths), n_images, W * H))
    for row, subject in enumerate(face_paths):
        for col, path in enumerate(subject[1:]):
            stack[row, col] = read(path)
    # Average over the per-subject image axis -> shape (subjects, W*H).
    return stack.mean(axis=1)
def mean_mtx_jagged(face_paths):
    """Return one mean image per subject, tolerating padded path lists.

    Row s of the (len(face_paths), W*H) result is the average of read(p)
    over the truthy entries p of face_paths[s][1:]; falsy placeholders (the
    None padding added in ``__main__``) are skipped.

    Raises:
        ValueError: if a subject has no usable image after its first one.
            Averaging an empty set would give a row of NaNs, and because
            argmin returns the first NaN it meets, score() would then match
            queries to that row.
    """
    mat = zeros((len(face_paths), W * H))
    for s, subj in enumerate(face_paths):
        paths = [f for f in subj[1:] if f]
        if not paths:
            raise ValueError("subject %d has no images after the first" % s)
        faces = array([read(f) for f in paths])
        mat[s] = faces.mean(0)
    return mat
def compare_imgs(face_paths, qi, ti):
    """Show subject ti's target image next to subject qi's query image.

    The left panel is face_paths[ti][0] and the right panel is
    face_paths[qi][1], both in grayscale and flipped vertically (as disp()
    does). Indexing as face_paths[x][y] works for both the list-of-lists
    built in ``__main__`` and the ndarray loaded from the .npz cache; the
    previous face_paths[x, y] form raised TypeError on a list.
    """
    gray()
    subplot(121)
    imshow(flipud(imread(face_paths[ti][0])))
    subplot(122)
    imshow(flipud(imread(face_paths[qi][1])))
    show()
if __name__ == "__main__":
    if len(argv) == 1:
        # Build mode: gather image paths, train the PCA basis, cache it.
        faces_dir = "orl_faces" if ATT else "lfw2"
        num_subs, num_faces = 100, 10
        # Keep the first num_subs subjects that have at least two images.
        # subj[0] is the target and subj[1] the query below, so a subject
        # with a single image would end up calling read(None). Sorting makes
        # the selection reproducible; glob order depends on the filesystem.
        face_paths = []
        for subj_dir in sorted(glob("%s/*" % faces_dir)):
            if len(face_paths) == num_subs:
                break
            paths = sorted(glob("%s/*" % subj_dir))[:num_faces]
            if len(paths) >= 2:
                face_paths.append(paths)
        if not face_paths:
            exit("No subject directory with at least two images in %s" % faces_dir)
        # numpy doesn't like jagged ndarrays: pad each subject with None.
        for paths in face_paths:
            paths.extend([None] * (num_faces - len(paths)))
        # Count exactly the images compute_vectors trains on: the non-None
        # entries of subj[2:] (no Python 2-only reduce builtin needed).
        num_trains = len([f for s in face_paths for f in s[2:] if f])
        basis, mean = compute_vectors(face_paths, num_trains)
        print("created basis")
        savez("basis-mean.npz", basis=basis, mean=mean, face_paths=face_paths)
    elif len(argv) == 2 and argv[1].split('.')[-1] == 'npz':
        # face_paths is stored as an object array when it holds None
        # padding, and numpy refuses to unpickle such arrays unless
        # allow_pickle=True. Only load .npz files you created yourself.
        npz = load(argv[1], allow_pickle=True)
        basis, mean, face_paths = npz['basis'], npz['mean'], npz['face_paths']
    else:
        exit("Usage: %s [basis-mean.npz]" % argv[0])

    # Each subject's first image is the target and its second the query.
    # Read them once; they do not change between basis slices.
    target_imgs = [read(s[0]) for s in face_paths]
    query_imgs = [read(s[1]) for s in face_paths]

    st = 0  # first basis column of the evaluated slice
    for i in [1, 2]:  # component counts; e.g. also try 3-10, 14-18, 30, 100
        target = [project(img, mean, basis, st, i + st) for img in target_imgs]
        query = [project(img, mean, basis, st, i + st) for img in query_imgs]
        sim = make_sim_matrix(target, query)
        print("%d %s" % (i, score(sim)))

    # Baseline: match each subject's first image against per-subject mean
    # images built from subj[1:].
    mmat = mean_mtx(face_paths) if ATT else mean_mtx_jagged(face_paths)
    msim = make_sim_matrix(mmat, target_imgs)
    print("mean-based alg: %s" % score(msim))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment