Skip to content

Instantly share code, notes, and snippets.

@valadhi
Created July 30, 2014 12:03
Show Gist options
  • Save valadhi/fcccb671c0ab6f986cfa to your computer and use it in GitHub Desktop.
Save valadhi/fcccb671c0ab6f986cfa to your computer and use it in GitHub Desktop.
def _save_gray_image(image, filename):
    """Draw `image` in grayscale with axes hidden and save it to `filename`.

    Draws onto the current matplotlib figure without clearing it first,
    exactly like the inline plotting code this helper replaces.
    """
    plt.imshow(image, cmap=plt.cm.gray)
    plt.axis('off')
    plt.savefig(filename, transparent=True)


def rbmEmotions(big=False, reconstructRandom=False):
    """Train (or load) an RBM on emotion data and save diagnostic plots.

    Data come from ``readother.read()``. Both ``data`` and ``labels`` are
    divided by 255 and concatenated column-wise. Every row except the last is
    used for training; the last row is held out and reconstructed.

    Behaviour is driven by the module-level ``args`` namespace:
      * ``args.relu``  -- use Rectified activations (and ``scale`` the data)
                          instead of Sigmoid.
      * ``args.train`` -- train a new RBM; otherwise unpickle the weight tile
                          image and the network from ``args.netFile``.
      * ``args.save``  -- pickle the weight tile image and the network to
                          ``args.netFile``.

    Writes ``initialface.png``, ``reconstructface.png`` and
    ``weightsrmsprop.png`` / ``weightssimple.png`` to the working directory.

    Args:
        big: Currently unused (only referenced by the commented-out
            ``readMultiPIE`` loader).
        reconstructRandom: Currently unused (random-input reconstruction is
            not implemented).
    """
    data, labels = readother.read()
    print("data.shape")
    print(data.shape)
    data = data / 255.0
    labels = labels / 255.0

    if args.relu:
        activationFunction = Rectified()
        data = scale(data)
    else:
        activationFunction = Sigmoid()

    # Each row is [data | labels]; the last row is held out as the test example.
    Data = np.concatenate((data, labels), axis=1)
    trainData = Data[0:-1, :]
    print("trainData %s" % (trainData.shape,))

    if args.train:
        # The visible layer must be as wide as the rows the RBM is trained on,
        # i.e. the data columns AND the appended label columns. (This used to
        # be len(data[0]), which ignored the label columns.)
        nrVisible = trainData.shape[1]
        nrHidden = 800
        # use 1 dropout to test the rbm for now
        net = rbm.RBM(nrVisible, nrHidden, 1.2, 1, 1,
                      visibleActivationFunction=activationFunction,
                      hiddenActivationFunction=activationFunction,
                      rmsprop=args.rbmrmsprop,
                      nesterov=args.rbmnesterov,
                      sparsityConstraint=args.sparsity,
                      sparsityRegularization=0.5,
                      trainingEpochs=args.maxEpochs,
                      # NOTE(review): 'Traget' spelling kept as-is; presumably
                      # it matches rbm.RBM's parameter name -- confirm before
                      # renaming.
                      sparsityTraget=0.01)
        net.train(trainData)
        print(net.weights.T.shape)
        t = visualizeWeights(net.weights.T, SMALL_SIZE, (10, 10))
    else:
        # Reuse a saved network. The file holds the weight tile image followed
        # by the network, in the order written by the args.save branch below.
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # network files you trust.
        with open(args.netFile, "rb") as f:
            t = pickle.load(f)
            net = pickle.load(f)

    # NOTE(review): `test` includes the appended label columns, so the
    # vectorToImage calls below assume SMALL_SIZE covers that full width --
    # confirm against readother's output.
    test = Data[-1, :]
    print("test.shape")
    print(test.shape)
    testRow = test.reshape(1, test.shape[0])

    # Show the initial image first, then its reconstruction.
    _save_gray_image(vectorToImage(test, SMALL_SIZE), 'initialface.png')
    recon = net.reconstruct(testRow)
    print(recon.shape)
    _save_gray_image(vectorToImage(recon, SMALL_SIZE), 'reconstructface.png')

    # Weights shown in a tile fashion; the filename records the optimiser.
    st = 'rmsprop' if args.rbmrmsprop else 'simple'
    _save_gray_image(t, 'weights' + st + '.png')

    # Sparsity check: sum of the hidden representation of the held-out row.
    hidden = net.hiddenRepresentation(testRow)
    print(hidden.sum())
    print("done")

    if args.save:
        # `with` guarantees the file is closed, so both pickles are flushed.
        with open(args.netFile, "wb") as f:
            pickle.dump(t, f)
            pickle.dump(net, f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment