Skip to content

Instantly share code, notes, and snippets.

Created June 24, 2015 13:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/53d7cb44c072ae6320ff to your computer and use it in GitHub Desktop.
Save anonymous/53d7cb44c072ae6320ff to your computer and use it in GitHub Desktop.
Caffe patch adding Net::BackwardFromToF
diff --git a/Makefile b/Makefile
index e4e66df..140acba 100644
--- a/Makefile
+++ b/Makefile
@@ -171,7 +171,7 @@ ifneq ($(CPU_ONLY), 1)
endif
LIBRARIES += glog gflags protobuf leveldb snappy \
lmdb boost_system hdf5_hl hdf5 m \
- opencv_core opencv_highgui opencv_imgproc
+ opencv_core opencv_highgui opencv_imgproc opencv_imgcodecs
PYTHON_LIBRARIES := boost_python python2.7
WARNINGS := -Wall -Wno-sign-compare
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index 5665df1..b849f3e 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -64,6 +64,7 @@ class Net {
*/
void Backward();
void BackwardFromTo(int start, int end);
+ void BackwardFromToF(int start, int end);
void BackwardFrom(int start);
void BackwardTo(int end);
diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m
index e6295bb..8f13b66 100644
--- a/matlab/+caffe/Net.m
+++ b/matlab/+caffe/Net.m
@@ -87,6 +87,12 @@ classdef Net < handle
function backward_prefilled(self)
caffe_('net_backward', self.hNet_self);
end
+ function backward_from_to(self, from, to) % thin wrapper over the 'net_backward_from_to' caffe_ command
+ caffe_('net_backward_from_to', self.hNet_self, from, to); % from/to are passed through unchanged
+ end
+ function backward_from_to_f(self, from, to) % thin wrapper over the 'net_backward_from_to_f' caffe_ command
+ caffe_('net_backward_from_to_f', self.hNet_self, from, to); % from/to are passed through unchanged
+ end
function res = forward(self, input_data)
CHECK(iscell(input_data), 'input_data must be a cell array');
CHECK(length(input_data) == length(self.inputs), ...
diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp
index 4e0ebc1..30e7f83 100644
--- a/matlab/+caffe/private/caffe_.cpp
+++ b/matlab/+caffe/private/caffe_.cpp
@@ -301,6 +301,16 @@ static void net_forward(MEX_ARGS) {
net->ForwardPrefilled();
}
+// Usage: caffe_('net_forward_from_to', hNet, from, to)
+static void net_forward_from_to(MEX_ARGS) {
+  mxCHECK(nrhs == 3 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('net_forward_from_to', hNet, from, to)");
+  Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
+  const int from = static_cast<int>(mxGetScalar(prhs[1]));  // double -> int
+  const int to = static_cast<int>(mxGetScalar(prhs[2]));    // double -> int
+  net->ForwardFromTo(from, to);
+}
+
// Usage: caffe_('net_backward', hNet)
static void net_backward(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -309,6 +319,25 @@ static void net_backward(MEX_ARGS) {
net->Backward();
}
+// Usage: caffe_('net_backward_from_to', hNet, from, to)
+static void net_backward_from_to(MEX_ARGS) {
+  mxCHECK(nrhs == 3 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('net_backward_from_to', hNet, from, to)");
+  Net<float>* const target = handle_to_ptr<Net<float> >(prhs[0]);
+  const int start = static_cast<int>(mxGetScalar(prhs[1]));  // 'from' argument
+  const int end = static_cast<int>(mxGetScalar(prhs[2]));    // 'to' argument
+  target->BackwardFromTo(start, end);
+}
+// Usage: caffe_('net_backward_from_to_f', hNet, from, to)
+static void net_backward_from_to_f(MEX_ARGS) {
+  mxCHECK(nrhs == 3 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('net_backward_from_to_f', hNet, from, to)");
+  Net<float>* const target = handle_to_ptr<Net<float> >(prhs[0]);
+  const int start = static_cast<int>(mxGetScalar(prhs[1]));  // 'from' argument
+  const int end = static_cast<int>(mxGetScalar(prhs[2]));    // 'to' argument
+  target->BackwardFromToF(start, end);
+}
+
// Usage: caffe_('net_copy_from', hNet, weights_file)
static void net_copy_from(MEX_ARGS) {
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
@@ -497,7 +526,10 @@ static handler_registry handlers[] = {
{ "get_net", get_net },
{ "net_get_attr", net_get_attr },
{ "net_forward", net_forward },
+ { "net_forward_from_to", net_forward_from_to },
{ "net_backward", net_backward },
+ { "net_backward_from_to", net_backward_from_to },
+ { "net_backward_from_to_f", net_backward_from_to_f },
{ "net_copy_from", net_copy_from },
{ "net_reshape", net_reshape },
{ "net_save", net_save },
diff --git a/models/bvlc_googlenet/deploy.prototxt b/models/bvlc_googlenet/deploy.prototxt
index 4648bf2..54923a1 100644
--- a/models/bvlc_googlenet/deploy.prototxt
+++ b/models/bvlc_googlenet/deploy.prototxt
@@ -1,4 +1,5 @@
name: "GoogleNet"
+force_backward: true
input: "data"
input_dim: 10
input_dim: 3
diff --git a/models/bvlc_reference_caffenet/deploy.prototxt b/models/bvlc_reference_caffenet/deploy.prototxt
index 29ccf14..072ec90 100644
--- a/models/bvlc_reference_caffenet/deploy.prototxt
+++ b/models/bvlc_reference_caffenet/deploy.prototxt
@@ -1,4 +1,5 @@
name: "CaffeNet"
+force_backward: true
input: "data"
input_dim: 10
input_dim: 3
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index a18ee63..3360d2f 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -577,6 +577,20 @@ void Net<Dtype>::BackwardFromTo(int start, int end) {
}
template <typename Dtype>
+void Net<Dtype>::BackwardFromToF(int start, int end) {
+  CHECK_GE(end, 0);
+  CHECK_LT(start, layers_.size());
+  // Walk layers start..end (inclusive, descending) and run Backward on each.
+  for (int id = start; id >= end; --id) {
+    // Set every bottom flag of this layer; the flags stay set afterwards.
+    vector<bool>& flags = bottom_need_backward_[id];
+    flags.assign(flags.size(), true);
+    layers_[id]->Backward(top_vecs_[id], flags, bottom_vecs_[id]);
+    if (debug_info_) { BackwardDebugInfo(id); }
+  }
+}
+
+template <typename Dtype>
void Net<Dtype>::InputDebugInfo(const int input_id) {
const Blob<Dtype>& blob = *net_input_blobs_[input_id];
const string& blob_name = blob_names_[net_input_blob_indices_[input_id]];
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment