Skip to content

Instantly share code, notes, and snippets.

@knsong
Last active October 15, 2016 04:41
Show Gist options
  • Save knsong/082f449e98ce168d85c050b15e8b2e88 to your computer and use it in GitHub Desktop.
Save knsong/082f449e98ce168d85c050b15e8b2e88 to your computer and use it in GitHub Desktop.
Regression accuracy layer for Caffe (regression_accuracy_layer.cpp and regression_accuracy_layer.hpp)
#include "caffe/layers/regression_accuracy_layer.hpp"

#include <cmath>
namespace caffe{

// Validates the bottom shapes and sizes the scalar top blob.
// The shape checks live here rather than in Forward_cpu because Caffe calls
// Reshape both at net setup and before every Forward, so a malformed net is
// rejected as early as possible.
template <typename Dtype>
void RegressionAccuracyLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top){
  // Each sample must carry exactly one scalar prediction.
  CHECK_EQ(bottom[0]->count(1), 1) << "RegressionAccuracyLayer only support 1-D now!";
  CHECK_EQ(bottom[0]->count(), bottom[1]->count()) << "Number of labels must match number of predictions; ";
  vector<int> top_shape(0);  // Accuracy is a scalar; 0 axes.
  top[0]->Reshape(top_shape);
}

// Writes to top[0] the fraction of samples whose rounded prediction
// (bottom[0]) equals the integer-cast label (bottom[1]).
// An empty batch yields 0 instead of NaN.
template <typename Dtype>
void RegressionAccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top){
  const Dtype* bottom_data = bottom[0]->cpu_data();
  const Dtype* bottom_label = bottom[1]->cpu_data();
  const int num = bottom[0]->shape(0);
  // DLOG: per-batch/per-sample diagnostics are compiled out of release
  // builds; LOG(INFO) here fired on every test iteration.
  DLOG(INFO) << "total: " << num;
  Dtype accuracy = 0;
  for (int i = 0; i < num; ++i) {
    // Round half up via floor(x + 0.5). A bare static_cast<int>(x + 0.5)
    // truncates toward zero and mis-rounds negative predictions
    // (e.g. -1.2 -> 0 instead of -1).
    const int predict_label_value =
        static_cast<int>(std::floor(bottom_data[i] + Dtype(0.5)));
    // NOTE(review): labels are assumed to be stored as exact integral values.
    const int label_value = static_cast<int>(bottom_label[i]);
    DLOG(INFO) << "Predicted_Value: " << predict_label_value << " Actual Label : " << label_value;
    if (predict_label_value == label_value) {
      ++accuracy;
    }
  }
  // Guard against an empty batch, which would otherwise divide by zero.
  top[0]->mutable_cpu_data()[0] = (num > 0) ? accuracy / num : Dtype(0);
}

INSTANTIATE_CLASS(RegressionAccuracyLayer);
REGISTER_LAYER_CLASS(RegressionAccuracy);
}//namespace caffe
#ifndef REGRESSION_ACCURACY_LAYER_HPP_
#define REGRESSION_ACCURACY_LAYER_HPP_
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe{

/**
 * @brief Accuracy metric for a scalar regression output.
 *
 * Counts the samples whose rounded prediction equals the integer-cast label,
 * and divides by the number of samples.
 *
 * Bottoms: [0] predictions and [1] labels, with one value per sample and
 *          equal element counts.
 * Top:     [0] a 0-axis (scalar) blob holding the accuracy.
 *
 * The layer is not differentiable. Backward fails only if a gradient is
 * actually requested for one of the bottoms.
 */
template <typename Dtype>
class RegressionAccuracyLayer : public Layer<Dtype>{
 public:
  explicit RegressionAccuracyLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "RegressionAccuracy"; }
  virtual inline int ExactNumBottomBlobs() const { return 2; }
  virtual inline int MinTopBlobs() const { return 1; }
  // Only top[0] is ever shaped and written. Reject extra tops at setup
  // instead of leaving them uninitialised.
  virtual inline int MaxTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  // A metric has no gradient. Abort only if a caller actually asks for one,
  // so merely wiring this layer into a net that runs backward is harmless.
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    for (size_t i = 0; i < propagate_down.size(); ++i) {
      if (propagate_down[i]) { NOT_IMPLEMENTED; }
    }
  }
};
}//namespace caffe
#endif//#ifndef REGRESSION_ACCURACY_LAYER_HPP_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment