Last active
February 27, 2017 12:07
-
-
Save lukeyeager/5ae183044d96ef1540f2 to your computer and use it in GitHub Desktop.
Caffe - Rewrite Accuracy layer as a Python layer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import caffe | |
import json | |
class AccuracyLayer(caffe.Layer):
    r"""
    Rewrite Accuracy layer as a Python layer.

    Computes top-k classification accuracy over a batch: the fraction of
    samples whose ground-truth label is among the ``top_k`` highest-scoring
    classes.

    Accepts JSON-encoded parameters through param_str:
      top_k         int, default 1. How many best-scoring classes count as a
                    hit. Must be >= 1 and <= the number of classes.
      ignore_label  optional. Samples whose label equals this value are left
                    out of both the hit count and the sample count.

    Bottoms:
      bottom[0]  prediction scores, one row per sample. Trailing axes are
                 flattened into the class axis, so (N, C) and (N, C, 1, 1)
                 both work. Spatial (H, W > 1) predictions are not supported.
      bottom[1]  ground-truth class indices, exactly one per sample
                 (e.g. shape (N,) or (N, 1, 1, 1)).
    Top:
      top[0]     a single value: the accuracy, in [0, 1]. It is 0 when every
                 sample in the batch is ignored.

    Use like this:
    layer {
      name: "accuracy"
      type: "Python"
      bottom: "pred"
      bottom: "label"
      top: "accuracy"
      include {
        phase: TEST
      }
      python_param {
        module: "accuracy_layer"
        layer: "AccuracyLayer"
        param_str: "{\"top_k\": 2}"
      }
    }
    """

    def setup(self, bottom, top):
        """Check the layer wiring and parse the JSON param_str."""
        # Raise rather than assert so the checks survive `python -O`.
        if len(bottom) != 2:
            raise ValueError('requires two layer.bottoms')
        if len(top) != 1:
            raise ValueError('requires a single layer.top')
        # param_str is absent or empty when the prototxt does not set it.
        if getattr(self, 'param_str', None):
            params = json.loads(self.param_str)
        else:
            params = {}
        self.top_k = int(params.get('top_k', 1))
        if self.top_k < 1:
            raise ValueError(
                'top_k must be a positive integer, got {}'.format(self.top_k))
        self.ignore_label = params.get('ignore_label')

    def reshape(self, bottom, top):
        """Validate that the bottom shapes agree and size the output."""
        num = bottom[0].data.shape[0]
        num_labels = bottom[1].data.size
        if num_labels != num:
            raise ValueError(
                'expected one label per sample: {} predictions vs {} labels'
                .format(num, num_labels))
        if num:
            num_classes = bottom[0].data.size // num
            # Otherwise every sample would trivially count as correct.
            if self.top_k > num_classes:
                raise ValueError(
                    'top_k ({}) exceeds the number of classes ({})'
                    .format(self.top_k, num_classes))
        top[0].reshape(1)

    def forward(self, bottom, top):
        """Write the batch's top-k accuracy into top[0]."""
        num = bottom[0].data.shape[0]
        # Flatten trailing axes. Without this, argsort would run over the
        # last (possibly size-1) axis of e.g. an (N, C, 1, 1) blob.
        scores = bottom[0].data.reshape(num, -1)
        labels = bottom[1].data.reshape(-1)
        # Sorting the negated scores in ascending order yields class indices
        # from best to worst; keep the first top_k of each row.
        top_predictions = (-scores).argsort(axis=1)[:, :self.top_k]
        # A sample is a hit when its label appears among its top_k classes.
        hits = (top_predictions == labels[:, None]).any(axis=1)
        if self.ignore_label is not None:
            hits = hits[labels != self.ignore_label]
        # Accuracy is averaged over the (non-ignored) samples in the batch.
        top[0].data[0] = hits.mean() if hits.size else 0.0

    def backward(self, top, propagate_down, bottom):
        """No-op: accuracy is a metric with no gradient to propagate."""
        pass
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
What happens during the test phase? Will the accuracy get averaged over all the test iterations?