@solrex
Created July 30, 2017 09:44
[caffe-mobile] Transpose layer patch
diff --git a/include/caffe/layers/transpose_layer.hpp b/include/caffe/layers/transpose_layer.hpp
new file mode 100644
index 0000000..aedb1ff
--- /dev/null
+++ b/include/caffe/layers/transpose_layer.hpp
@@ -0,0 +1,50 @@
+#ifndef TRANSPOSE_LAYER_HPP_
+#define TRANSPOSE_LAYER_HPP_
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
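+/**
+ * @brief Permutes the axes of the bottom blob according to
+ *        TransposeParameter; e.g. with dim: 3 dim: 0 dim: 2 dim: 1,
+ *        an NxCxHxW bottom becomes a WxNxHxC top.
+ */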
+template <typename Dtype>
+class TransposeLayer : public Layer<Dtype> {
+ public:
+  explicit TransposeLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Transpose"; }
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ private:
+  TransposeParameter transpose_param_;
+  vector<int> permute(const vector<int>& vec);
+  Blob<int> bottom_counts_;
+  Blob<int> top_counts_;
+  Blob<int> forward_map_;
+  Blob<int> backward_map_;
+  Blob<int> buf_;
+};
+
+}  // namespace caffe
+
+#endif  // TRANSPOSE_LAYER_HPP_
diff --git a/src/caffe/layers/transpose_layer.cpp b/src/caffe/layers/transpose_layer.cpp
new file mode 100644
index 0000000..a55112d
--- /dev/null
+++ b/src/caffe/layers/transpose_layer.cpp
@@ -0,0 +1,122 @@
+#include "caffe/layers/transpose_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void transpose_cpu(const int count, const Dtype* from_data, Dtype* to_data,
+    const int* from_counts, const int* to_counts, const int* map, const int num_axes) {
+  int from_inds[kMaxBlobAxes] = {0};
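+  // For each flat index into the source blob: decompose it into per-axis
+  // coordinates using the source strides, then rebuild the destination
+  // flat index by permuting those coordinates through `map` and applying
+  // the destination strides.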
+  for (int index = 0; index < count; index++) {
+    int from_index = index, to_index = 0;
+    for (int i = 0; i < num_axes; i++) {
+      from_inds[i] = from_index / from_counts[i];
+      from_index = from_index % from_counts[i];
+    }
+    for (int i = 0; i < num_axes; i++) {
+      to_index += from_inds[map[i]] * to_counts[i];
+    }
+
+    to_data[to_index] = from_data[index];
+  }
+}
+
+template <typename Dtype>
+void TransposeLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  CHECK_NE(bottom[0], top[0]) << this->type() << " Layer does not support "
+      "in-place computation.";
+  transpose_param_ = this->layer_param_.transpose_param();
+}
+
+template <typename Dtype>
+void TransposeLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  vector<int> shape = bottom[0]->shape();
+  CHECK_GT(shape.size(), 0) << "the dimension of the transposed blob should "
+      "be greater than 0.";
+  CHECK_LE(shape.size(), kMaxBlobAxes) << "the dimension of the transposed "
+      "blob must not exceed kMaxBlobAxes (" << kMaxBlobAxes << ").";
+  CHECK_EQ(shape.size(), transpose_param_.dim_size()) << "the number of axes "
+      "of the bottom blob must equal the number of dims in transpose_param.";
+  vector<int> top_shape = permute(shape);
+  top[0]->Reshape(top_shape);
+
+  const int num_axes = transpose_param_.dim_size();
+  shape.clear();
+  shape.push_back(num_axes);
+
+  bottom_counts_.Reshape(shape);
+  top_counts_.Reshape(shape);
+
+  int* bottom_counts_data = bottom_counts_.mutable_cpu_data();
+  int* top_counts_data = top_counts_.mutable_cpu_data();
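+  // counts_[i] holds the stride of axis i, i.e. the product of all
+  // dimensions after axis i; the innermost axis gets stride 1.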
+  for (int i = 1; i < num_axes; i++) {
+    *bottom_counts_data = bottom[0]->count(i);
+    *top_counts_data = top[0]->count(i);
+    bottom_counts_data++;
+    top_counts_data++;
+  }
+  *bottom_counts_data = 1;
+  *top_counts_data = 1;
+
+  forward_map_.Reshape(shape);
+  backward_map_.Reshape(shape);
+
+  int* forward_map_data = forward_map_.mutable_cpu_data();
+  int* backward_map_data = backward_map_.mutable_cpu_data();
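+  // forward_map_[i] is the bottom axis that becomes top axis i;
+  // backward_map_ stores the inverse permutation.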
+  for (int i = 0; i < num_axes; i++) {
+    *forward_map_data = transpose_param_.dim(i);
+    backward_map_data[transpose_param_.dim(i)] = i;
+    forward_map_data++;
+  }
+
+  shape.clear();
+  shape.push_back(bottom[0]->count() * num_axes);
+  buf_.Reshape(shape);
+}
+
+template <typename Dtype>
+vector<int> TransposeLayer<Dtype>::permute(const vector<int>& vec) {
+  vector<int> new_vec(vec.size());
+  for (int i = 0; i < vec.size(); i++) {
+    new_vec[i] = vec[transpose_param_.dim(i)];
+  }
+  return new_vec;
+}
+
+
+template <typename Dtype>
+void TransposeLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  transpose_cpu<Dtype>(bottom[0]->count(), bottom[0]->cpu_data(), top[0]->mutable_cpu_data(),
+      bottom_counts_.cpu_data(), top_counts_.cpu_data(), forward_map_.cpu_data(),
+      bottom[0]->shape().size());
+}
+
+template <typename Dtype>
+void TransposeLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) {
+    return;
+  }
+  transpose_cpu<Dtype>(bottom[0]->count(), top[0]->cpu_diff(), bottom[0]->mutable_cpu_diff(),
+      top_counts_.cpu_data(), bottom_counts_.cpu_data(), backward_map_.cpu_data(),
+      bottom[0]->shape().size());
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(TransposeLayer);
+#endif
+
+INSTANTIATE_CLASS(TransposeLayer);
+REGISTER_LAYER_CLASS(Transpose);
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index aa77dae..89df7fa 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -407,6 +407,7 @@ message LayerParameter {
   optional ThresholdParameter threshold_param = 128;
   optional TileParameter tile_param = 138;
   optional WindowDataParameter window_data_param = 129;
+  optional TransposeParameter transpose_param = 147;
 }
 
 // Message that stores parameters used to apply transformation
@@ -1412,3 +1413,11 @@ message PReLUParameter {
   // Whether or not slope parameters are shared across channels.
   optional bool channel_shared = 2 [default = false];
 }
+
+message TransposeParameter {
+  // For example, to transpose an NxCxHxW blob into WxNxHxC, set:
+  //   transpose_param { dim: 3 dim: 0 dim: 2 dim: 1 }
+  // i.e., if the i-th dim has value n, then the i-th axis of the top
+  // blob is the n-th axis of the bottom blob.
+  repeated int32 dim = 1;
+}
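With the patch applied, a net prototxt could use the layer along these lines. This is only a sketch: the layer and blob names are placeholders, and the dim values follow the example from the TransposeParameter comment above.

layer {
  name: "trans"      # hypothetical layer name
  type: "Transpose"
  bottom: "data"     # hypothetical input blob
  top: "data_t"
  # NxCxHxW -> WxNxHxC, as in the TransposeParameter comment
  transpose_param { dim: 3 dim: 0 dim: 2 dim: 1 }
}

To make the index arithmetic in transpose_cpu concrete, here is a minimal standalone sketch (plain C++, no Caffe dependency; every name is local to this example) of the same stride-decomposition trick, transposing a 2x3 matrix:

#include <cstdio>

int main() {
  const int num_axes = 2;
  const int map[] = {1, 0};          // top axis i takes bottom axis map[i]
  const int from_counts[] = {3, 1};  // strides of the 2x3 source
  const int to_counts[] = {2, 1};    // strides of the 3x2 destination
  const float from_data[] = {0, 1, 2, 3, 4, 5};  // [[0 1 2], [3 4 5]]
  float to_data[6];

  for (int index = 0; index < 6; index++) {
    // Decompose the flat source index into per-axis coordinates.
    int from_inds[2];
    int from_index = index, to_index = 0;
    for (int i = 0; i < num_axes; i++) {
      from_inds[i] = from_index / from_counts[i];
      from_index %= from_counts[i];
    }
    // Recompose the destination index from the permuted coordinates.
    for (int i = 0; i < num_axes; i++) {
      to_index += from_inds[map[i]] * to_counts[i];
    }
    to_data[to_index] = from_data[index];
  }

  // Prints the 3x2 transpose, one row per line: 0 3 / 1 4 / 2 5.
  for (int r = 0; r < 3; r++) {
    printf("%g %g\n", to_data[r * 2 + 0], to_data[r * 2 + 1]);
  }
  return 0;
}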