Skip to content

Instantly share code, notes, and snippets.

@asanakoy
Created December 18, 2015 22:54
Show Gist options
  • Save asanakoy/ba70c3eeff6da26d68d4 to your computer and use it in GitHub Desktop.
Save asanakoy/ba70c3eeff6da26d68d4 to your computer and use it in GitHub Desktop.
Script to show non-deterministic caffe behaviour
from caffeext import *
import caffe
import os.path
from scipy import misc
import time
def _iter_batches(items, batch_size):
    """Yield consecutive lists of at most ``batch_size`` elements from ``items``."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def run_test_on_images(net, transformer, data, batch_size):
    """Classify images with ``net`` in batches and measure top-1 accuracy.

    :param net: caffe.Net with an input blob named 'data' and an output blob
        named 'prob'; its 'data' blob is reshaped by this function
    :param transformer: caffe.io.Transformer configured for the 'data' input
    :param data: iterable of (image_path, label) pairs
    :param batch_size: number of images per forward pass; a trailing partial
        batch is run with the input blob reshaped to its actual size
    :return: tuple (accuracy, results, prob): ``accuracy`` is the fraction of
        images whose argmax prediction equals the label (NaN if ``data`` is
        empty), ``results`` is a list of (image_path, predicted_label) and
        ``prob`` is a list of per-image probability vectors
    :raises ValueError: if ``batch_size`` is smaller than 1
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))

    data_blob_shape = net.blobs['data'].data.shape
    channels, height, width = data_blob_shape[1], data_blob_shape[2], data_blob_shape[3]
    net.blobs['data'].reshape(batch_size, channels, height, width)

    results = []        # (image_path, predicted_label) per image
    prob = []           # probability vector per image
    right_answers = []  # per image: did the prediction match the label?
    for batch in _iter_batches(data, batch_size):
        if len(batch) != batch_size:
            # Only the trailing batch can be short; shrink the input blob so the
            # forward pass yields exactly one prediction per remaining image.
            net.blobs['data'].reshape(len(batch), channels, height, width)
        img_paths = [img_path for img_path, _ in batch]
        labels = np.asarray([label for _, label in batch])
        # A list (not a lazy map) so the assignment also works on Python 3.
        net.blobs['data'].data[...] = [
            transformer.preprocess('data', misc.imread(img_path)) for img_path in img_paths]
        # pycaffe's Net.forward returns the output blob's own array, not a copy.
        # Copy it so probabilities kept from this batch are not overwritten by
        # the next forward pass.
        batch_prob = np.array(net.forward()['prob'], copy=True)
        predict = np.argmax(batch_prob, axis=1)
        prob.extend(batch_prob)
        results.extend(zip(img_paths, predict))
        right_answers.extend(predict == labels)

    if not right_answers:
        # No images were processed: avoid a 0/0 division (NaN plus a warning).
        return float('nan'), results, prob
    accuracy = np.mean(right_answers)
    return accuracy, results, prob
########################################################################################################################
def test_network(network_root_path, snapshot_iteration, sample, batch_sizes=(1, 2, 3)):
    """Classify one sample repeatedly with several batch sizes and print the outputs.

    The same ``sample`` is fed three times per run. For each batch size, the
    first four class probabilities of every copy and the resulting accuracy are
    printed. This makes any batch-size-dependent or run-to-run difference in
    the network output visible.

    :param network_root_path: network root folder (``~`` is expanded). It must
        contain ``model/net_config/deploy.prototxt``,
        ``model/snap_iter_<snapshot_iteration>.caffemodel`` and
        ``train.leveldb/mean.binaryproto``
    :param snapshot_iteration: training iteration of the snapshot to load
    :param sample: a single ``(image_path, label)`` pair
    :param batch_sizes: batch sizes to evaluate with (default: 1, 2 and 3)
    """
    network_root_path = os.path.expanduser(network_root_path)
    # Select the GPU before constructing the net (the conventional pycaffe order).
    caffe.set_mode_gpu()
    net = caffe.Net(
        os.path.join(network_root_path, "model/net_config/deploy.prototxt"),
        os.path.join(network_root_path, "model/snap_iter_{}.caffemodel".format(snapshot_iteration)),
        caffe.TEST)

    mean = protomean2array(os.path.join(network_root_path, "train.leveldb/mean.binaryproto"))
    # transformer transforms image from RGB HxWxC -> BGR CxHxW and subtracts the mean
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))  # height*width*channel -> channel*height*width
    transformer.set_mean('data', mean)  # subtract mean
    transformer.set_raw_scale('data', 1)  # keep raw pixel values (no rescaling)
    transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR

    samples = [sample, sample, sample]
    for batch_size in batch_sizes:
        accuracy, _, probabilities = run_test_on_images(net, transformer, samples, batch_size=batch_size)
        for probability in probabilities:
            print(probability[0:4])
        print('Accuracy: {}'.format(accuracy))
        print('--')
def main():
    """Run the batch-size consistency check on one hard-coded snapshot and sample."""
    network_root_path = '~/workspace/meisterwerke/cnn/51_100_test'
    snapshot_iteration = 210000
    sample = ('/export/home/asanakoy/workspace/meisterwerke/crops_227x227_step30/mwm35259_0_flipped.png', 0)
    print('Testing network {} on {} iter'.format(network_root_path, snapshot_iteration))
    # Measure wall-clock time. time.clock() returns CPU time on Unix, which
    # under-reports time spent waiting on the GPU; it was also removed in
    # Python 3.8.
    start_time = time.time()
    test_network(network_root_path, snapshot_iteration, sample)
    print('Elapsed time: {} s'.format(time.time() - start_time))
# Run the check only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment