Created
March 1, 2015 06:48
-
-
Save avdmitry/bbddbc0be6ad564114fa to your computer and use it in GitHub Desktop.
data_transformer.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <opencv2/core/core.hpp> | |
#include <opencv2/imgproc/imgproc.hpp> | |
#include <string> | |
#include <vector> | |
#include "caffe/data_transformer.hpp" | |
#include "caffe/util/io.hpp" | |
#include "caffe/util/math_functions.hpp" | |
#include "caffe/util/rng.hpp" | |
namespace caffe { | |
template<typename Dtype> | |
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, | |
Phase phase) | |
: param_(param), phase_(phase) { | |
// check if we want to use mean_file | |
if (param_.has_mean_file()) { | |
CHECK_EQ(param_.mean_value_size(), 0) << | |
"Cannot specify mean_file and mean_value at the same time"; | |
const string& mean_file = param.mean_file(); | |
LOG(INFO) << "Loading mean file from: " << mean_file; | |
BlobProto blob_proto; | |
ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); | |
data_mean_.FromProto(blob_proto); | |
} | |
// check if we want to use mean_value | |
if (param_.mean_value_size() > 0) { | |
CHECK(param_.has_mean_file() == false) << | |
"Cannot specify mean_file and mean_value at the same time"; | |
for (int c = 0; c < param_.mean_value_size(); ++c) { | |
mean_values_.push_back(param_.mean_value(c)); | |
} | |
} | |
} | |
template <typename Dtype> | |
cv::Mat DatumToCVMat(const Datum& datum) { | |
const string& data = datum.data(); | |
const int datum_channels = datum.channels(); | |
const int datum_height = datum.height(); | |
const int datum_width = datum.width(); | |
const bool has_uint8 = data.size() > 0; | |
cv::Mat img = cv::Mat(datum_height, datum_width, | |
CV_MAKETYPE(cv::DataDepth<Dtype>::value, datum_channels)); | |
CHECK(img.depth() == CV_8U || img.depth() == CV_32F) << "unsupported image data type"; | |
Dtype datum_element; | |
int index; | |
for (int h = 0; h < datum_height; ++h) { | |
uchar* ptr = img.ptr<uchar>(h); | |
float* ptrf = img.ptr<float>(h); | |
int img_index = 0; | |
for (int w = 0; w < datum_width; ++w) { | |
for (int c = 0; c < datum_channels; ++c) { | |
index = (c * datum_height + h) * datum_width + w; | |
if (has_uint8) { | |
datum_element = | |
static_cast<Dtype>(static_cast<uint8_t>(data[index])); | |
} else { | |
datum_element = datum.float_data(index); | |
} | |
if (img.depth()==CV_8U) | |
{ | |
ptr[img_index++] = datum_element; | |
} else // CV_32F | |
{ | |
ptrf[img_index++] = datum_element; | |
} | |
} | |
} | |
} | |
return img; | |
} | |
// Transforms one Datum into the caller-provided CHW buffer transformed_data:
// optional random mirror, random (TRAIN) or center (TEST) crop of size
// crop_size, mean subtraction (mean file or per-channel values), and scaling.
// transformed_data must hold channels * crop_h * crop_w elements (the full
// datum size when crop_size == 0).
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  // Mirroring is applied with probability 1/2 when enabled.
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  // Non-empty byte payload means uint8 pixels; otherwise use float_data.
  const bool has_uint8 = data.size() > 0;
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean blob must match the *uncropped* input dimensions exactly.
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
        "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < datum_channels; ++c) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int height = datum_height;
  int width = datum_width;
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    height = crop_size;
    width = crop_size;
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      // Deterministic center crop at test time.
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  Dtype datum_element;
  int top_index, data_index;
  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Source pixel in the (possibly larger) datum, shifted by crop offset.
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {
          // Horizontal flip: column w lands in column (width - 1 - w).
          top_index = (c * height + h) * width + (width - 1 - w);
        } else {
          top_index = (c * height + h) * width + w;
        }
        if (has_uint8) {
          datum_element =
              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {
          datum_element = datum.float_data(data_index);
        }
        if (has_mean_file) {
          // Pixel-wise mean is indexed in source (uncropped) coordinates.
          transformed_data[top_index] =
              (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {
            transformed_data[top_index] =
                (datum_element - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = datum_element * scale;
          }
        }
      }
    }
  }
}
template<typename Dtype> | |
void DataTransformer<Dtype>::Transform(const Datum& datum, | |
Blob<Dtype>* transformed_blob) { | |
const int datum_channels = datum.channels(); | |
const int datum_height = datum.height(); | |
const int datum_width = datum.width(); | |
const int channels = transformed_blob->channels(); | |
const int height = transformed_blob->height(); | |
const int width = transformed_blob->width(); | |
const int num = transformed_blob->num(); | |
CHECK_EQ(channels, datum_channels); | |
CHECK_LE(height, datum_height); | |
CHECK_LE(width, datum_width); | |
CHECK_GE(num, 1); | |
const int crop_size = param_.crop_size(); | |
if (crop_size) { | |
CHECK_EQ(crop_size, height); | |
CHECK_EQ(crop_size, width); | |
} else { | |
CHECK_EQ(datum_height, height); | |
CHECK_EQ(datum_width, width); | |
} | |
cv::Mat cv_img = DatumToCVMat<Dtype>(datum); | |
Transform(cv_img, transformed_blob); | |
} | |
template<typename Dtype> | |
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector, | |
Blob<Dtype>* transformed_blob) { | |
const int datum_num = datum_vector.size(); | |
const int num = transformed_blob->num(); | |
const int channels = transformed_blob->channels(); | |
const int height = transformed_blob->height(); | |
const int width = transformed_blob->width(); | |
CHECK_GT(datum_num, 0) << "There is no datum to add"; | |
CHECK_LE(datum_num, num) << | |
"The size of datum_vector must be no greater than transformed_blob->num()"; | |
Blob<Dtype> uni_blob(1, channels, height, width); | |
for (int item_id = 0; item_id < datum_num; ++item_id) { | |
int offset = transformed_blob->offset(item_id); | |
uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset); | |
Transform(datum_vector[item_id], &uni_blob); | |
} | |
} | |
template<typename Dtype> | |
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector, | |
Blob<Dtype>* transformed_blob) { | |
const int mat_num = mat_vector.size(); | |
const int num = transformed_blob->num(); | |
const int channels = transformed_blob->channels(); | |
const int height = transformed_blob->height(); | |
const int width = transformed_blob->width(); | |
CHECK_GT(mat_num, 0) << "There is no MAT to add"; | |
CHECK_EQ(mat_num, num) << | |
"The size of mat_vector must be equals to transformed_blob->num()"; | |
Blob<Dtype> uni_blob(1, channels, height, width); | |
for (int item_id = 0; item_id < mat_num; ++item_id) { | |
int offset = transformed_blob->offset(item_id); | |
uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset); | |
Transform(mat_vector[item_id], &uni_blob); | |
} | |
} | |
template<typename Dtype> | |
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img, | |
Blob<Dtype>* transformed_blob) { | |
const int img_channels = cv_img.channels(); | |
const int img_height = cv_img.rows; | |
const int img_width = cv_img.cols; | |
const int channels = transformed_blob->channels(); | |
const int height = transformed_blob->height(); | |
const int width = transformed_blob->width(); | |
const int num = transformed_blob->num(); | |
CHECK_EQ(channels, img_channels); | |
CHECK_GE(num, 1); | |
CHECK(cv_img.depth() == CV_8U || cv_img.depth() == CV_32F) << "unsupported image data type"; | |
const int crop_size = param_.crop_size(); | |
const Dtype scale = param_.scale(); | |
const bool do_mirror = param_.mirror() && Rand(2); | |
const bool has_mean_file = param_.has_mean_file(); | |
const bool has_mean_values = mean_values_.size() > 0; | |
const bool has_min_size = param_.has_min_size(); | |
const bool has_max_size = param_.has_max_size(); | |
const int max_size = param_.max_size(); | |
int min_size = param_.min_size(); | |
CHECK_GT(img_channels, 0); | |
Dtype* mean = NULL; | |
if (has_mean_file) { | |
CHECK_EQ(img_channels, data_mean_.channels()); | |
CHECK_EQ(img_height, data_mean_.height()); | |
CHECK_EQ(img_width, data_mean_.width()); | |
mean = data_mean_.mutable_cpu_data(); | |
} | |
if (has_mean_values) { | |
CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) << | |
"Specify either 1 mean_value or as many as channels: " << img_channels; | |
if (img_channels > 1 && mean_values_.size() == 1) { | |
// Replicate the mean_value for simplicity | |
for (int c = 1; c < img_channels; ++c) { | |
mean_values_.push_back(mean_values_[0]); | |
} | |
} | |
} | |
// if min_size is not set, set to crop_size | |
if (has_min_size) { | |
CHECK_GE(min_size, crop_size) << "Minimum image size must be >= crop_size"; | |
} else { | |
min_size = crop_size; | |
} | |
cv::Mat cv_resized_img = cv_img; | |
int img_resized_width = img_width; | |
int img_resized_height = img_height; | |
int resize_size = min_size; | |
// We do scale augmentation at TRAIN time only | |
// Handle 'scale' augmentation by randomly choosing a minimum image | |
// size to be cropped from, between [min_size, max_size], if | |
// max_size and min_size are set. | |
if (phase_ == TRAIN && has_max_size) { | |
CHECK_GE(max_size, min_size); | |
resize_size = min_size + Rand(max_size - min_size); | |
} | |
// if either of the sides of the image are less than crop size, | |
// enlarge to a given size_size >= crop_size | |
if (img_height < img_width && img_height < resize_size) { | |
// resize height so it is crop_size | |
img_resized_height = resize_size; | |
img_resized_width = | |
ceil(static_cast<float>(resize_size*img_width)/img_height); | |
} else if (img_width < resize_size) { | |
// resize width so it is crop_size | |
img_resized_width = resize_size; | |
img_resized_height = | |
ceil(static_cast<float>(resize_size*img_height)/img_width); | |
} | |
if (img_width != img_resized_width | |
|| img_height != img_resized_height) { | |
// In practice Andrew Howard used CUBIC for both upsampling | |
// and downsampling, Googlenet used random assortment of methods. | |
CHECK_EQ(img_channels, 3) | |
<< "Currently resizing of input is only supported for 3 channel inputs"; | |
CHECK(!has_mean_file) << "Currently resizing of input is only supported for mean values"; | |
// and downsampling, Googlenet used random assortment of methods. | |
cv::resize(cv_img, cv_resized_img, | |
cv::Size(img_resized_width, img_resized_height), cv::INTER_CUBIC); | |
} | |
int h_off = 0; | |
int w_off = 0; | |
cv::Mat cv_cropped_img = cv_resized_img; | |
if (crop_size) { | |
CHECK_EQ(crop_size, height); | |
CHECK_EQ(crop_size, width); | |
// We only do random crop when we do training. | |
if (phase_ == TRAIN) { | |
h_off = Rand(img_resized_height - crop_size + 1); | |
w_off = Rand(img_resized_width - crop_size + 1); | |
} else { | |
h_off = (img_resized_height - crop_size) / 2; | |
w_off = (img_resized_width - crop_size) / 2; | |
} | |
cv::Rect roi(w_off, h_off, crop_size, crop_size); | |
cv_cropped_img = cv_resized_img(roi); | |
} else { | |
CHECK_EQ(img_resized_height, height); | |
CHECK_EQ(img_resized_width, width); | |
} | |
CHECK(cv_cropped_img.data); | |
Dtype* transformed_data = transformed_blob->mutable_cpu_data(); | |
int top_index; | |
for (int h = 0; h < height; ++h) { | |
const uchar* ptr = cv_cropped_img.ptr<uchar>(h); | |
const float* ptrf = cv_cropped_img.ptr<float>(h); | |
int img_index = 0; | |
for (int w = 0; w < width; ++w) { | |
for (int c = 0; c < img_channels; ++c) { | |
if (do_mirror) { | |
top_index = (c * height + h) * width + (width - 1 - w); | |
} else { | |
top_index = (c * height + h) * width + w; | |
} | |
Dtype pixel; | |
if (cv_cropped_img.depth()==CV_8U) | |
{ | |
pixel = static_cast<Dtype>(ptr[img_index++]); | |
} else // CV_32F | |
{ | |
pixel = static_cast<Dtype>(ptrf[img_index++]); | |
} | |
if (has_mean_file) { | |
int mean_index = (c * img_resized_height + h_off + h) | |
* img_resized_width + w_off + w; | |
transformed_data[top_index] = | |
(pixel - mean[mean_index]) * scale; | |
} else { | |
if (has_mean_values) { | |
transformed_data[top_index] = | |
(pixel - mean_values_[c]) * scale; | |
} else { | |
transformed_data[top_index] = pixel * scale; | |
} | |
} | |
} | |
} | |
} | |
} | |
template<typename Dtype> | |
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob, | |
Blob<Dtype>* transformed_blob) { | |
cv::Mat cv_img = BlobToCVMat<Dtype, 3>(input_blob); | |
Transform(cv_img, transformed_blob); | |
} | |
template <typename Dtype> | |
void DataTransformer<Dtype>::InitRand() { | |
const bool needs_rand = param_.mirror() || | |
(phase_ == TRAIN && param_.crop_size()); | |
if (needs_rand) { | |
const unsigned int rng_seed = caffe_rng_rand(); | |
rng_.reset(new Caffe::RNG(rng_seed)); | |
} else { | |
rng_.reset(); | |
} | |
} | |
template <typename Dtype> | |
int DataTransformer<Dtype>::Rand(int n) { | |
CHECK(rng_); | |
CHECK_GT(n, 0); | |
caffe::rng_t* rng = | |
static_cast<caffe::rng_t*>(rng_->generator()); | |
return ((*rng)() % n); | |
} | |
// Explicit instantiation for the supported Dtypes (float and double).
INSTANTIATE_CLASS(DataTransformer);
}  // namespace caffe
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment