@rueycheng · Last active February 12, 2023
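The patch below is against LightGBM. It adds an `ordering` metadata field to `Dataset`/`Metadata`, fed either through `SetFloatField("ordering", ...)` or from a `<data_filename>.ordering` text file next to the training data (mirroring the existing `.weight` handling), and registers a new `multiobjlambdarank` objective, `MultiObjLambdarankNDCG`, intended to combine LambdaRank NDCG over the relevance labels with RankNet over the ordering (e.g. a negative timestamp). Two hedged sketches follow the diff.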
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 12dbe6c..ef058af 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -88,6 +88,8 @@ class Metadata {
void SetLabel(const label_t* label, data_size_t len);
+ void SetOrdering(const label_t* ordering, data_size_t len);
+
void SetWeights(const label_t* weights, data_size_t len);
void SetQuery(const data_size_t* query, data_size_t len);
@@ -143,6 +145,18 @@ class Metadata {
queries_[idx] = static_cast<data_size_t>(value);
}
+ /*!
+ * \brief Get ordering, if not exists, will return nullptr
+ * \return Pointer of ordering
+ */
+ inline const label_t* ordering() const {
+ if (!ordering_.empty()) {
+ return ordering_.data();
+ } else {
+ return nullptr;
+ }
+ }
+
/*!
* \brief Get weights, if not exists, will return nullptr
* \return Pointer of weights
@@ -213,6 +227,8 @@ class Metadata {
private:
/*! \brief Load initial scores from file */
void LoadInitialScore(const char* initscore_file);
+ /*! \brief Load ordering from file */
+ void LoadOrdering();
/*! \brief Load wights from file */
void LoadWeights();
/*! \brief Load query boundaries from file */
@@ -223,10 +239,14 @@ class Metadata {
std::string data_filename_;
/*! \brief Number of data */
data_size_t num_data_;
+ /*! \brief Number of ordering entries, used to check correct ordering file */
+ data_size_t num_ordering_;
/*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_;
/*! \brief Label data */
std::vector<label_t> label_;
+ /*! \brief Ordering data */
+ std::vector<label_t> ordering_;
/*! \brief Weights data */
std::vector<label_t> weights_;
/*! \brief Query boundaries */
@@ -243,6 +263,7 @@ class Metadata {
std::vector<data_size_t> queries_;
/*! \brief mutex for threading safe call */
std::mutex mutex_;
+ bool ordering_load_from_file_;
bool weight_load_from_file_;
bool query_load_from_file_;
bool init_score_load_from_file_;
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index ccb72b9..e095997 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -516,6 +516,12 @@ bool Dataset::SetFloatField(const char* field_name, const float* field_data, dat
#else
metadata_.SetLabel(field_data, num_element);
#endif
+ } else if (name == std::string("ordering")) {
+ #ifdef LABEL_T_USE_DOUBLE
+ Log::Fatal("Don't support LABEL_T_USE_DOUBLE");
+ #else
+ metadata_.SetOrdering(field_data, num_element);
+ #endif
} else if (name == std::string("weight") || name == std::string("weights")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't support LABEL_T_USE_DOUBLE");
@@ -560,6 +566,13 @@ bool Dataset::GetFloatField(const char* field_name, data_size_t* out_len, const
*out_ptr = metadata_.label();
*out_len = num_data_;
#endif
+ } else if (name == std::string("ordering")) {
+ #ifdef LABEL_T_USE_DOUBLE
+ Log::Fatal("Don't support LABEL_T_USE_DOUBLE");
+ #else
+ *out_ptr = metadata_.ordering();
+ *out_len = num_data_;
+ #endif
} else if (name == std::string("weight") || name == std::string("weights")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't support LABEL_T_USE_DOUBLE");
diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp
index a73ec45..b1a8cc8 100644
--- a/src/io/metadata.cpp
+++ b/src/io/metadata.cpp
@@ -11,10 +11,12 @@
namespace LightGBM {
Metadata::Metadata() {
+ num_ordering_ = 0;
num_weights_ = 0;
num_init_score_ = 0;
num_data_ = 0;
num_queries_ = 0;
+ ordering_load_from_file_ = false;
weight_load_from_file_ = false;
query_load_from_file_ = false;
init_score_load_from_file_ = false;
@@ -24,6 +26,7 @@ void Metadata::Init(const char * data_filename, const char* initscore_file) {
data_filename_ = data_filename;
// for lambdarank, it needs query data for partition data in parallel learning
LoadQueryBoundaries();
+ LoadOrdering();
LoadWeights();
LoadQueryWeights();
LoadInitialScore(initscore_file);
@@ -72,6 +75,17 @@ void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, da
label_[i] = fullset.label_[used_indices[i]];
}
+ if (!fullset.ordering_.empty()) {
+ ordering_ = std::vector<label_t>(num_used_indices);
+ num_ordering_ = num_used_indices;
+#pragma omp parallel for schedule(static)
+ for (data_size_t i = 0; i < num_used_indices; i++) {
+ ordering_[i] = fullset.ordering_[used_indices[i]];
+ }
+ } else {
+ num_ordering_ = 0;
+ }
+
if (!fullset.weights_.empty()) {
weights_ = std::vector<label_t>(num_used_indices);
num_weights_ = num_used_indices;
@@ -171,6 +185,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
LoadQueryWeights();
queries_.clear();
}
+ // check ordering
+ if (!ordering_.empty() && num_ordering_ != num_data_) {
+ ordering_.clear();
+ num_ordering_ = 0;
+ Log::Fatal("Ordering size doesn't match data size");
+ }
+
// check weights
if (!weights_.empty() && num_weights_ != num_data_) {
weights_.clear();
@@ -196,6 +217,25 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
Log::Fatal("Cannot used query_id for parallel training");
}
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
+ // check ordering
+ if (ordering_load_from_file_) {
+ if (ordering_.size() > 0 && num_ordering_ != num_all_data) {
+ ordering_.clear();
+ num_ordering_ = 0;
+ Log::Fatal("ordering size doesn't match data size");
+ }
+ // get local ordering
+ if (!ordering_.empty()) {
+ auto old_ordering = ordering_;
+ num_ordering_ = num_data_;
+ ordering_ = std::vector<label_t>(num_data_);
+#pragma omp parallel for schedule(static)
+ for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
+ ordering_[i] = old_ordering[used_data_indices[i]];
+ }
+ old_ordering.clear();
+ }
+ }
// check weights
if (weight_load_from_file_) {
if (weights_.size() > 0 && num_weights_ != num_all_data) {
@@ -320,6 +360,28 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) {
}
}
+void Metadata::SetOrdering(const label_t* ordering, data_size_t len) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ // save to nullptr
+ if (ordering == nullptr || len == 0) {
+ ordering_.clear();
+ num_ordering_ = 0;
+ return;
+ }
+ if (num_data_ != len) {
+ Log::Fatal("Length of ordering is not same with #data");
+ }
+ if (!ordering_.empty()) { ordering_.clear(); }
+ num_ordering_ = num_data_;
+ ordering_ = std::vector<label_t>(num_ordering_);
+#pragma omp parallel for schedule(static)
+ for (data_size_t i = 0; i < num_ordering_; ++i) {
+ ordering_[i] = ordering[i];
+ }
+ // do not call LoadOrdering() here: it would clobber the values just set
+ ordering_load_from_file_ = false;
+}
+
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
@@ -369,6 +431,28 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
query_load_from_file_ = false;
}
+void Metadata::LoadOrdering() {
+ num_ordering_ = 0;
+ std::string ordering_filename(data_filename_);
+ // default ordering file name
+ ordering_filename.append(".ordering");
+ TextReader<size_t> reader(ordering_filename.c_str(), false);
+ reader.ReadAllLines();
+ if (reader.Lines().empty()) {
+ return;
+ }
+ Log::Info("Loading ordering...");
+ num_ordering_ = static_cast<data_size_t>(reader.Lines().size());
+ ordering_ = std::vector<label_t>(num_ordering_);
+#pragma omp parallel for schedule(static)
+ for (data_size_t i = 0; i < num_ordering_; ++i) {
+ double tmp_value = 0.0f;
+ Common::Atof(reader.Lines()[i].c_str(), &tmp_value);
+ ordering_[i] = static_cast<label_t>(tmp_value);
+ }
+ ordering_load_from_file_ = true;
+}
+
void Metadata::LoadWeights() {
num_weights_ = 0;
std::string weight_filename(data_filename_);
@@ -485,6 +569,10 @@ void Metadata::LoadFromMemory(const void* memory) {
num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
mem_ptr += sizeof(num_queries_);
+ // TODO: ordering is not stored in the binary format yet, so there is
+ // nothing to load here
+ num_ordering_ = 0;
+
if (!label_.empty()) { label_.clear(); }
label_ = std::vector<label_t>(num_data_);
std::memcpy(label_.data(), mem_ptr, sizeof(label_t)*num_data_);
@@ -512,6 +600,9 @@ void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
writer->Write(&num_weights_, sizeof(num_weights_));
writer->Write(&num_queries_, sizeof(num_queries_));
writer->Write(label_.data(), sizeof(label_t) * num_data_);
+
+ // TODO: ordering is not written to the binary format yet; see the
+ // matching note in LoadFromMemory
if (!weights_.empty()) {
writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
}
diff --git a/src/objective/objective_function.cpp b/src/objective/objective_function.cpp
index 9cf030a..77d5bb6 100644
--- a/src/objective/objective_function.cpp
+++ b/src/objective/objective_function.cpp
@@ -31,6 +31,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new BinaryLogloss(config);
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(config);
+ } else if (type == std::string("multiobjlambdarank")) {
+ return new MultiObjLambdarankNDCG(config);
} else if (type == std::string("multiclass") || type == std::string("softmax")) {
return new MulticlassSoftmax(config);
} else if (type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
@@ -70,6 +72,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new BinaryLogloss(strs);
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(strs);
+ } else if (type == std::string("multiobjlambdarank")) {
+ return new MultiObjLambdarankNDCG(strs);
} else if (type == std::string("multiclass")) {
return new MulticlassSoftmax(strs);
} else if (type == std::string("multiclassova")) {
diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp
index 785ac89..2df7559 100644
--- a/src/objective/rank_objective.hpp
+++ b/src/objective/rank_objective.hpp
@@ -240,5 +240,233 @@ class LambdarankNDCG: public ObjectiveFunction {
double sigmoid_table_idx_factor_;
};
+/*!
+* \brief Objective function combining LambdaRank NDCG over the relevance labels
+* and RankNet over an ordering signal (e.g. a negative timestamp)
+*/
+class MultiObjLambdarankNDCG: public ObjectiveFunction {
+ public:
+ explicit MultiObjLambdarankNDCG(const Config& config) {
+ sigmoid_ = static_cast<double>(config.sigmoid);
+ label_gain_ = config.label_gain;
+ // initialize DCG calculator
+ DCGCalculator::DefaultLabelGain(&label_gain_);
+ DCGCalculator::Init(label_gain_);
+ // will optimize NDCG@optimize_pos_at_
+ optimize_pos_at_ = config.max_position;
+ sigmoid_table_.clear();
+ inverse_max_dcgs_.clear();
+ if (sigmoid_ <= 0.0) {
+ Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_);
+ }
+ }
+
+ explicit MultiObjLambdarankNDCG(const std::vector<std::string>&) {
+ }
+
+ ~MultiObjLambdarankNDCG() {
+ }
+ void Init(const Metadata& metadata, data_size_t num_data) override {
+ num_data_ = num_data;
+ // get label
+ label_ = metadata.label();
+ DCGCalculator::CheckLabel(label_, num_data_);
+ // get ordering
+ ordering_ = metadata.ordering();
+ // get weights
+ weights_ = metadata.weights();
+ // get query boundaries
+ query_boundaries_ = metadata.query_boundaries();
+ if (query_boundaries_ == nullptr) {
+ Log::Fatal("Lambdarank tasks require query information");
+ }
+ num_queries_ = metadata.num_queries();
+ // cache inverse max DCG to avoid repeated computation
+ inverse_max_dcgs_.resize(num_queries_);
+#pragma omp parallel for schedule(static)
+ for (data_size_t i = 0; i < num_queries_; ++i) {
+ inverse_max_dcgs_[i] = DCGCalculator::CalMaxDCGAtK(optimize_pos_at_,
+ label_ + query_boundaries_[i],
+ query_boundaries_[i + 1] - query_boundaries_[i]);
+
+ if (inverse_max_dcgs_[i] > 0.0) {
+ inverse_max_dcgs_[i] = 1.0f / inverse_max_dcgs_[i];
+ }
+ }
+ // construct sigmoid table to speed up sigmoid transform
+ ConstructSigmoidTable();
+ }
+
+ void GetGradients(const double* score, score_t* gradients,
+ score_t* hessians) const override {
+ #pragma omp parallel for schedule(guided)
+ for (data_size_t i = 0; i < num_queries_; ++i) {
+ GetGradientsForOneQuery(score, gradients, hessians, i);
+ }
+ }
+
+ inline void GetGradientsForOneQuery(const double* score,
+ score_t* lambdas, score_t* hessians, data_size_t query_id) const {
+ // get doc boundary for current query
+ const data_size_t start = query_boundaries_[query_id];
+ const data_size_t cnt =
+ query_boundaries_[query_id + 1] - query_boundaries_[query_id];
+ // get max DCG on current query
+ const double inverse_max_dcg = inverse_max_dcgs_[query_id];
+ // add pointers with offset
+ const label_t* label = label_ + start;
+ score += start;
+ lambdas += start;
+ hessians += start;
+ // initialize with zero
+ for (data_size_t i = 0; i < cnt; ++i) {
+ lambdas[i] = 0.0f;
+ hessians[i] = 0.0f;
+ }
+ // get sorted indices for scores
+ std::vector<data_size_t> sorted_idx;
+ for (data_size_t i = 0; i < cnt; ++i) {
+ sorted_idx.emplace_back(i);
+ }
+ std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
+ [score](data_size_t a, data_size_t b) { return score[a] > score[b]; });
+ // get best and worst score
+ const double best_score = score[sorted_idx[0]];
+ data_size_t worst_idx = cnt - 1;
+ if (worst_idx > 0 && score[sorted_idx[worst_idx]] == kMinScore) {
+ worst_idx -= 1;
+ }
+ const double worst_score = score[sorted_idx[worst_idx]];
+ // start accumulating lambdas by pairs
+ for (data_size_t i = 0; i < cnt; ++i) {
+ const data_size_t high = sorted_idx[i];
+ const int high_label = static_cast<int>(label[high]);
+ const double high_score = score[high];
+ if (high_score == kMinScore) { continue; }
+ const double high_label_gain = label_gain_[high_label];
+ const double high_discount = DCGCalculator::GetDiscount(i);
+ double high_sum_lambda = 0.0;
+ double high_sum_hessian = 0.0;
+ for (data_size_t j = 0; j < cnt; ++j) {
+ // skip same data
+ if (i == j) { continue; }
+
+ const data_size_t low = sorted_idx[j];
+ const int low_label = static_cast<int>(label[low]);
+ const double low_score = score[low];
+ // only consider pair with different label
+ if (high_label <= low_label || low_score == kMinScore) { continue; }
+
+ const double delta_score = high_score - low_score;
+
+ const double low_label_gain = label_gain_[low_label];
+ const double low_discount = DCGCalculator::GetDiscount(j);
+ // get dcg gap
+ const double dcg_gap = high_label_gain - low_label_gain;
+ // get discount of this pair
+ const double paired_discount = fabs(high_discount - low_discount);
+ // get delta NDCG
+ double delta_pair_NDCG = dcg_gap * paired_discount * inverse_max_dcg;
+ // regularize delta_pair_NDCG by the score distance
+ if (high_label != low_label && best_score != worst_score) {
+ delta_pair_NDCG /= (0.01f + fabs(delta_score));
+ }
+ // calculate lambda for this pair
+ double p_lambda = GetSigmoid(delta_score);
+ double p_hessian = p_lambda * (2.0f - p_lambda);
+ // update
+ p_lambda *= -delta_pair_NDCG;
+ p_hessian *= 2 * delta_pair_NDCG;
+ high_sum_lambda += p_lambda;
+ high_sum_hessian += p_hessian;
+ lambdas[low] -= static_cast<score_t>(p_lambda);
+ hessians[low] += static_cast<score_t>(p_hessian);
+ }
+ // update
+ lambdas[high] += static_cast<score_t>(high_sum_lambda);
+ hessians[high] += static_cast<score_t>(high_sum_hessian);
+ }
+ // if need weights
+ if (weights_ != nullptr) {
+ for (data_size_t i = 0; i < cnt; ++i) {
+ lambdas[i] = static_cast<score_t>(lambdas[i] * weights_[start + i]);
+ hessians[i] = static_cast<score_t>(hessians[i] * weights_[start + i]);
+ }
+ }
+ }
+
+
+ inline double GetSigmoid(double score) const {
+ if (score <= min_sigmoid_input_) {
+ // too small, use lower bound
+ return sigmoid_table_[0];
+ } else if (score >= max_sigmoid_input_) {
+ // too big, use upper bound
+ return sigmoid_table_[_sigmoid_bins - 1];
+ } else {
+ return sigmoid_table_[static_cast<size_t>((score - min_sigmoid_input_) * sigmoid_table_idx_factor_)];
+ }
+ }
+
+ void ConstructSigmoidTable() {
+ // get boundary
+ min_sigmoid_input_ = min_sigmoid_input_ / sigmoid_ / 2;
+ max_sigmoid_input_ = -min_sigmoid_input_;
+ sigmoid_table_.resize(_sigmoid_bins);
+ // get score to bin factor
+ sigmoid_table_idx_factor_ =
+ _sigmoid_bins / (max_sigmoid_input_ - min_sigmoid_input_);
+ // cache
+ for (size_t i = 0; i < _sigmoid_bins; ++i) {
+ const double score = i / sigmoid_table_idx_factor_ + min_sigmoid_input_;
+ sigmoid_table_[i] = 2.0f / (1.0f + std::exp(2.0f * score * sigmoid_));
+ }
+ }
+
+ const char* GetName() const override {
+ return "multiobjlambdarank";
+ }
+
+ std::string ToString() const override {
+ std::stringstream str_buf;
+ str_buf << GetName();
+ return str_buf.str();
+ }
+
+ bool NeedAccuratePrediction() const override { return false; }
+
+ private:
+ /*! \brief Gains for labels */
+ std::vector<double> label_gain_;
+ /*! \brief Cache inverse max DCG, speed up calculation */
+ std::vector<double> inverse_max_dcgs_;
+ /*! \brief Sigmoid param */
+ double sigmoid_;
+ /*! \brief Optimized NDCG@ */
+ int optimize_pos_at_;
+ /*! \brief Number of queries */
+ data_size_t num_queries_;
+ /*! \brief Number of data */
+ data_size_t num_data_;
+ /*! \brief Pointer of label */
+ const label_t* label_;
+ /*! \brief Pointer of ordering */
+ const label_t* ordering_;
+ /*! \brief Pointer of weights */
+ const label_t* weights_;
+ /*! \brief Query boundaries */
+ const data_size_t* query_boundaries_;
+ /*! \brief Cache result for sigmoid transform to speed up */
+ std::vector<double> sigmoid_table_;
+ /*! \brief Number of bins in sigmoid table */
+ size_t _sigmoid_bins = 1024 * 1024;
+ /*! \brief Minimal input of sigmoid table */
+ double min_sigmoid_input_ = -50;
+ /*! \brief Maximal input of sigmoid table */
+ double max_sigmoid_input_ = 50;
+ /*! \brief Factor that converts a score to a bin in the sigmoid table */
+ double sigmoid_table_idx_factor_;
+};
+
} // namespace LightGBM
#endif // LightGBM_OBJECTIVE_RANK_OBJECTIVE_HPP_
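
Two notes on the patch as posted.

First, MultiObjLambdarankNDCG::Init fetches ordering_ from the metadata, but GetGradientsForOneQuery above never reads it, so the RankNet-over-ordering half promised in the class comment is not yet wired in. Below is a minimal sketch of what that pairwise term could look like, written as a free function for clarity; the blend weight kOrderingAlpha and the function name are illustrative assumptions, not part of the gist.

#include <cmath>

using score_t = float;
using label_t = float;
using data_size_t = int;

// Assumed blend weight between the NDCG term and the ordering term (not in the gist).
static const double kOrderingAlpha = 0.5;

// Plain RankNet pass over one query, driven by the ordering signal
// (larger ordering value = should rank higher, e.g. a negative timestamp).
// Mirrors the sign and hessian conventions of GetGradientsForOneQuery.
void AccumulateOrderingGradients(const double* score, const label_t* ordering,
                                 data_size_t cnt, double sigmoid,
                                 score_t* lambdas, score_t* hessians) {
  for (data_size_t i = 0; i < cnt; ++i) {
    for (data_size_t j = 0; j < cnt; ++j) {
      // only pairs where document i should precede document j
      if (!(ordering[i] > ordering[j])) { continue; }
      const double delta_score = score[i] - score[j];
      // same sigmoid shape as ConstructSigmoidTable: 2 / (1 + exp(2 * s * sigmoid))
      const double p_lambda = 2.0 / (1.0 + std::exp(2.0 * delta_score * sigmoid));
      const double p_hessian = p_lambda * (2.0 - p_lambda);
      // a negative gradient pushes the score of i up and of j down
      lambdas[i] -= static_cast<score_t>(kOrderingAlpha * p_lambda);
      lambdas[j] += static_cast<score_t>(kOrderingAlpha * p_lambda);
      hessians[i] += static_cast<score_t>(kOrderingAlpha * p_hessian);
      hessians[j] += static_cast<score_t>(kOrderingAlpha * p_hessian);
    }
  }
}

Calling this at the end of GetGradientsForOneQuery with ordering_ + start (guarded by if (ordering_ != nullptr)) would fold the ordering objective into the existing lambdas and hessians.

Second, on feeding the ordering in: the diff accepts it either from a <data_filename>.ordering text file (one value per line, picked up by LoadOrdering() exactly like .weight) or through Dataset::SetFloatField, which is the path the existing C API entry point LGBM_DatasetSetField dispatches to. Assuming the patch is applied, a call like the following should work; the wrapper function is mine.

#include <cstdio>
#include <vector>
#include <LightGBM/c_api.h>

// Attach an ordering (one float per row) to an already-constructed dataset.
int SetOrdering(DatasetHandle dataset, const std::vector<float>& ordering) {
  int err = LGBM_DatasetSetField(dataset, "ordering", ordering.data(),
                                 static_cast<int>(ordering.size()),
                                 C_API_DTYPE_FLOAT32);
  if (err != 0) {
    std::fprintf(stderr, "set ordering failed: %s\n", LGBM_GetLastError());
  }
  return err;
}

Training against the new objective is then a matter of setting objective=multiobjlambdarank in the config, which CreateObjectiveFunction maps to MultiObjLambdarankNDCG.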