Skip to content

Instantly share code, notes, and snippets.

@avdmitry
Created March 1, 2015 06:48
Show Gist options
  • Save avdmitry/bbddbc0be6ad564114fa to your computer and use it in GitHub Desktop.
Save avdmitry/bbddbc0be6ad564114fa to your computer and use it in GitHub Desktop.
data_transformer.cpp
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <string>
#include <vector>
#include "caffe/data_transformer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
/**
 * @brief Builds a transformer from its parameter message and the phase.
 *
 * Stores both arguments, then loads whichever mean source is configured:
 * a binary mean file (parsed into data_mean_) or a list of per-channel mean
 * values (copied into mean_values_).  Configuring both is a fatal error.
 *
 * @param param transformation settings; copied into param_.
 * @param phase phase the transformer runs in; stored in phase_.
 */
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
                                        Phase phase)
    : param_(param), phase_(phase) {
  // Mean image stored in a binary proto file.
  if (param_.has_mean_file()) {
    CHECK_EQ(param_.mean_value_size(), 0) <<
        "Cannot specify mean_file and mean_value at the same time";
    const string& mean_path = param.mean_file();
    LOG(INFO) << "Loading mean file from: " << mean_path;
    BlobProto mean_proto;
    ReadProtoFromBinaryFileOrDie(mean_path.c_str(), &mean_proto);
    data_mean_.FromProto(mean_proto);
  }

  // Explicit mean values listed in the parameter message.
  const int num_mean_values = param_.mean_value_size();
  if (num_mean_values > 0) {
    CHECK(param_.has_mean_file() == false) <<
        "Cannot specify mean_file and mean_value at the same time";
    for (int i = 0; i < num_mean_values; ++i) {
      mean_values_.push_back(param_.mean_value(i));
    }
  }
}
/**
 * @brief Copies the pixel payload of a Datum into a newly allocated cv::Mat.
 *
 * Source elements are addressed as (c * height + h) * width + w and written
 * to each image row in column-major, channel-minor order.  The byte string
 * datum.data() is the source when it is non-empty; otherwise values are read
 * from datum.float_data().
 *
 * The image depth is cv::DataDepth<Dtype>::value and must be CV_8U or
 * CV_32F.  The selected payload must contain at least
 * channels * height * width elements; both conditions are fatal if violated.
 *
 * @param datum source datum.
 * @return image with datum.height() rows, datum.width() columns and
 *         datum.channels() channels.
 */
template <typename Dtype>
cv::Mat DatumToCVMat(const Datum& datum) {
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();
  const bool has_uint8 = data.size() > 0;

  cv::Mat img = cv::Mat(datum_height, datum_width,
      CV_MAKETYPE(cv::DataDepth<Dtype>::value, datum_channels));
  CHECK(img.depth() == CV_8U || img.depth() == CV_32F)
      << "unsupported image data type";

  // Refuse to read past the end of the payload: the loops below index up to
  // channels * height * width - 1 without any further bounds checking.
  const size_t num_elements = static_cast<size_t>(datum_channels) *
      static_cast<size_t>(datum_height) * static_cast<size_t>(datum_width);
  if (has_uint8) {
    CHECK(data.size() >= num_elements)
        << "Datum byte data holds fewer than channels * height * width "
        << "elements";
  } else {
    CHECK(static_cast<size_t>(datum.float_data_size()) >= num_elements)
        << "Datum float data holds fewer than channels * height * width "
        << "elements";
  }

  // The depth is fixed for the whole image, so decide the store type once.
  const bool is_8u = (img.depth() == CV_8U);
  for (int h = 0; h < datum_height; ++h) {
    uchar* ptr = img.ptr<uchar>(h);
    float* ptrf = img.ptr<float>(h);
    int img_index = 0;
    for (int w = 0; w < datum_width; ++w) {
      for (int c = 0; c < datum_channels; ++c) {
        const int index = (c * datum_height + h) * datum_width + w;
        const Dtype datum_element = has_uint8
            ? static_cast<Dtype>(static_cast<uint8_t>(data[index]))
            : static_cast<Dtype>(datum.float_data(index));
        if (is_8u) {
          ptr[img_index++] = static_cast<uchar>(datum_element);
        } else {  // CV_32F
          ptrf[img_index++] = static_cast<float>(datum_element);
        }
      }
    }
  }
  return img;
}
/**
 * @brief Transforms one Datum into a caller-provided output buffer.
 *
 * Optionally crops (random offset in TRAIN, centered otherwise), mirrors
 * horizontally, subtracts a mean (mean file or per-channel values) and
 * multiplies by the configured scale.  Output is written in
 * (c * height + h) * width + w order, where height/width are the cropped
 * dimensions.
 *
 * @param datum            source datum; bytes from data() are used when
 *                         present, float_data() otherwise.
 * @param transformed_data destination buffer written by this function.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  // The mirror decision is drawn before the crop offsets.
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_uint8 = data.size() > 0;
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
        "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // A single mean value is expanded to one entry per channel.
      while (mean_values_.size() < static_cast<size_t>(datum_channels)) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  // Output extent and the top-left corner of the region copied from the datum.
  const int height = crop_size ? crop_size : datum_height;
  const int width = crop_size ? crop_size : datum_width;
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    if (phase_ == TRAIN) {
      // Random crop while training.
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      // Center crop otherwise.
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      const int src_row = (c * datum_height + h_off + h) * datum_width + w_off;
      const int dst_row = (c * height + h) * width;
      for (int w = 0; w < width; ++w) {
        const int src = src_row + w;
        const int dst = dst_row + (do_mirror ? (width - 1 - w) : w);
        const Dtype value = has_uint8
            ? static_cast<Dtype>(static_cast<uint8_t>(data[src]))
            : static_cast<Dtype>(datum.float_data(src));
        if (has_mean_file) {
          transformed_data[dst] = (value - mean[src]) * scale;
        } else if (has_mean_values) {
          transformed_data[dst] = (value - mean_values_[c]) * scale;
        } else {
          transformed_data[dst] = value * scale;
        }
      }
    }
  }
}
/**
 * @brief Transforms one Datum into a blob by way of an OpenCV image.
 *
 * Checks that the blob's shape is compatible with the datum and the
 * configured crop size, converts the datum with DatumToCVMat and hands the
 * image to the cv::Mat overload of Transform.
 *
 * @param datum            source datum.
 * @param transformed_blob destination blob.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Blob<Dtype>* transformed_blob) {
  // Source geometry.
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Destination geometry.
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, datum_channels);
  CHECK_LE(height, datum_height);
  CHECK_LE(width, datum_width);
  CHECK_GE(num, 1);

  // With cropping the blob must be crop_size x crop_size; without it the
  // blob must match the datum exactly.
  const int crop_size = param_.crop_size();
  if (!crop_size) {
    CHECK_EQ(datum_height, height);
    CHECK_EQ(datum_width, width);
  } else {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
  }

  const cv::Mat cv_img = DatumToCVMat<Dtype>(datum);
  Transform(cv_img, transformed_blob);
}
/**
 * @brief Transforms a batch of Datum objects into consecutive blob items.
 *
 * Item i of datum_vector is written at transformed_blob->offset(i).  The
 * vector may hold fewer items than the blob, but not more and not zero.
 *
 * @param datum_vector     source datums.
 * @param transformed_blob destination blob.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int datum_num = datum_vector.size();
  const int num = transformed_blob->num();
  const int item_channels = transformed_blob->channels();
  const int item_height = transformed_blob->height();
  const int item_width = transformed_blob->width();

  CHECK_GT(datum_num, 0) << "There is no datum to add";
  CHECK_LE(datum_num, num) <<
      "The size of datum_vector must be no greater than transformed_blob->num()";

  // A one-item blob is re-pointed at each slot of the destination in turn.
  Blob<Dtype> item_view(1, item_channels, item_height, item_width);
  int item_id = 0;
  while (item_id < datum_num) {
    const int slot_offset = transformed_blob->offset(item_id);
    Dtype* const slot = transformed_blob->mutable_cpu_data() + slot_offset;
    item_view.set_cpu_data(slot);
    Transform(datum_vector[item_id], &item_view);
    ++item_id;
  }
}
/**
 * @brief Transforms a batch of OpenCV images into consecutive blob items.
 *
 * Image i of mat_vector is written at transformed_blob->offset(i).  The
 * vector must be non-empty and hold exactly transformed_blob->num() images.
 *
 * @param mat_vector       source images.
 * @param transformed_blob destination blob.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int mat_num = mat_vector.size();
  const int num = transformed_blob->num();
  const int item_channels = transformed_blob->channels();
  const int item_height = transformed_blob->height();
  const int item_width = transformed_blob->width();

  CHECK_GT(mat_num, 0) << "There is no MAT to add";
  CHECK_EQ(mat_num, num) <<
      "The size of mat_vector must be equals to transformed_blob->num()";

  // A one-item blob is re-pointed at each slot of the destination in turn.
  Blob<Dtype> item_view(1, item_channels, item_height, item_width);
  int item_id = 0;
  while (item_id < mat_num) {
    const int slot_offset = transformed_blob->offset(item_id);
    Dtype* const slot = transformed_blob->mutable_cpu_data() + slot_offset;
    item_view.set_cpu_data(slot);
    Transform(mat_vector[item_id], &item_view);
    ++item_id;
  }
}
/**
 * @brief Transforms a single OpenCV image into a blob.
 *
 * Steps, in order:
 *  1. Validate channel count, depth (CV_8U or CV_32F) and the mean
 *     configuration.
 *  2. Pick a target size: min_size (crop_size when min_size is unset), or,
 *     in TRAIN with max_size set, a random value in [min_size, max_size].
 *  3. If the shorter image side is below the target size, upscale the image
 *     with cubic interpolation so that side equals the target size (only
 *     3-channel images without a mean file are supported here).
 *  4. Crop crop_size x crop_size (random offset in TRAIN, centered
 *     otherwise) when crop_size is non-zero.
 *  5. Write (pixel - mean) * scale into the blob at
 *     (c * height + h) * width + w, mirrored horizontally when selected.
 *
 * @param cv_img           source image.
 * @param transformed_blob destination blob; its channels must equal the
 *                         image's and its height/width the output size.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                       Blob<Dtype>* transformed_blob) {
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, img_channels);
  CHECK_GE(num, 1);
  CHECK(cv_img.depth() == CV_8U || cv_img.depth() == CV_32F)
      << "unsupported image data type";

  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;
  const bool has_min_size = param_.has_min_size();
  const bool has_max_size = param_.has_max_size();
  const int max_size = param_.max_size();
  int min_size = param_.min_size();

  CHECK_GT(img_channels, 0);

  Dtype* mean = NULL;
  if (has_mean_file) {
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
        "Specify either 1 mean_value or as many as channels: " << img_channels;
    if (img_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < img_channels; ++c) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  // if min_size is not set, set to crop_size
  if (has_min_size) {
    CHECK_GE(min_size, crop_size) << "Minimum image size must be >= crop_size";
  } else {
    min_size = crop_size;
  }

  cv::Mat cv_resized_img = cv_img;
  int img_resized_width = img_width;
  int img_resized_height = img_height;
  int resize_size = min_size;

  // We do scale augmentation at TRAIN time only.
  // Handle 'scale' augmentation by randomly choosing a minimum image
  // size to be cropped from, in [min_size, max_size], if max_size is set.
  if (phase_ == TRAIN && has_max_size) {
    CHECK_GE(max_size, min_size);
    // Rand(n) returns a value in [0, n - 1]; the "+ 1" makes max_size itself
    // reachable and keeps n > 0 when max_size == min_size (Rand(0) is fatal).
    resize_size = min_size + Rand(max_size - min_size + 1);
  }

  // If the shorter side of the image is less than resize_size, enlarge the
  // image (keeping the aspect ratio) so that side equals resize_size.
  if (img_height < img_width && img_height < resize_size) {
    // Height is the shorter side.
    img_resized_height = resize_size;
    img_resized_width =
        ceil(static_cast<float>(resize_size*img_width)/img_height);
  } else if (img_width < resize_size) {
    // Width is the shorter (or equal) side.
    img_resized_width = resize_size;
    img_resized_height =
        ceil(static_cast<float>(resize_size*img_height)/img_width);
  }

  if (img_width != img_resized_width
      || img_height != img_resized_height) {
    CHECK_EQ(img_channels, 3)
        << "Currently resizing of input is only supported for 3 channel inputs";
    CHECK(!has_mean_file)
        << "Currently resizing of input is only supported for mean values";
    // In practice Andrew Howard used CUBIC for both upsampling
    // and downsampling, Googlenet used random assortment of methods.
    cv::resize(cv_img, cv_resized_img,
        cv::Size(img_resized_width, img_resized_height), cv::INTER_CUBIC);
  }

  int h_off = 0;
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_resized_img;
  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(img_resized_height - crop_size + 1);
      w_off = Rand(img_resized_width - crop_size + 1);
    } else {
      h_off = (img_resized_height - crop_size) / 2;
      w_off = (img_resized_width - crop_size) / 2;
    }
    cv::Rect roi(w_off, h_off, crop_size, crop_size);
    cv_cropped_img = cv_resized_img(roi);
  } else {
    CHECK_EQ(img_resized_height, height);
    CHECK_EQ(img_resized_width, width);
  }
  CHECK(cv_cropped_img.data);

  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  // The element type is the same for every pixel, so test it once.
  const bool is_8u = (cv_cropped_img.depth() == CV_8U);
  for (int h = 0; h < height; ++h) {
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
    const float* ptrf = cv_cropped_img.ptr<float>(h);
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < img_channels; ++c) {
        const int top_index = do_mirror
            ? (c * height + h) * width + (width - 1 - w)
            : (c * height + h) * width + w;
        Dtype pixel;
        if (is_8u) {
          pixel = static_cast<Dtype>(ptr[img_index++]);
        } else {  // CV_32F
          pixel = static_cast<Dtype>(ptrf[img_index++]);
        }
        if (has_mean_file) {
          // The mean image is indexed in the un-cropped coordinate frame.
          const int mean_index = (c * img_resized_height + h_off + h)
              * img_resized_width + w_off + w;
          transformed_data[top_index] =
              (pixel - mean[mean_index]) * scale;
        } else if (has_mean_values) {
          transformed_data[top_index] =
              (pixel - mean_values_[c]) * scale;
        } else {
          transformed_data[top_index] = pixel * scale;
        }
      }
    }
  }
}
// Transforms the contents of input_blob into transformed_blob.
//
// The input blob is first converted to a cv::Mat by BlobToCVMat<Dtype, 3>
// (defined outside this file); the resulting image is then passed to the
// cv::Mat overload of Transform, which does the actual work.
// NOTE(review): the template argument 3 presumably fixes the channel count
// of the produced image -- confirm against BlobToCVMat's declaration.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
Blob<Dtype>* transformed_blob) {
cv::Mat cv_img = BlobToCVMat<Dtype, 3>(input_blob);
Transform(cv_img, transformed_blob);
}
/**
 * @brief Creates or clears the random number generator used by Rand().
 *
 * A generator is needed whenever a transform may call Rand(): for mirroring
 * (any phase), and in TRAIN for random cropping (crop_size set) or scale
 * augmentation (max_size set).  When none of these apply rng_ is reset to
 * empty, and any later call to Rand() is fatal.
 */
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  const bool is_train = (phase_ == TRAIN);
  // Scale augmentation draws a random resize target, so it needs the
  // generator even when neither mirroring nor cropping is enabled.
  const bool needs_rand = param_.mirror() ||
      (is_train && (param_.crop_size() || param_.has_max_size()));
  if (needs_rand) {
    const unsigned int rng_seed = caffe_rng_rand();
    rng_.reset(new Caffe::RNG(rng_seed));
  } else {
    rng_.reset();
  }
}
/**
 * @brief Draws one value from the generator and reduces it modulo n.
 *
 * @param n modulus; must be greater than 0 (checked, fatal otherwise).
 * @return the generator's next output modulo n.
 *
 * It is a fatal error to call this while rng_ is empty.
 */
template <typename Dtype>
int DataTransformer<Dtype>::Rand(int n) {
  CHECK(rng_);
  CHECK_GT(n, 0);
  caffe::rng_t& generator =
      *static_cast<caffe::rng_t*>(rng_->generator());
  return generator() % n;
}
// Explicitly instantiates DataTransformer through the INSTANTIATE_CLASS
// macro, which is defined outside this file.
// NOTE(review): presumably this covers the float and double
// specializations -- confirm against the macro's definition.
INSTANTIATE_CLASS(DataTransformer);
} // namespace caffe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment