Created
November 21, 2022 04:39
-
-
Save komasaru/92f0888b18d2bb25d5649ea38d33a2f5 to your computer and use it in GitHub Desktop.
C++ source code to compute logistic regression.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "calc.hpp" | |
#include <cmath> | |
#include <iostream> | |
#include <sstream> | |
#include <vector> | |
// 定数 | |
static constexpr double kAlpha = 0.01; // 学習率 | |
static constexpr double kEps = 1.0e-12; // 閾値 | |
static constexpr unsigned int kLoop = 10000000; // 最大ループ回数 | |
static constexpr double kBeta = 5.0; // 初期値: β | |
static constexpr double kCel = 0.0; // 初期値: 交差エントロピー誤差 | |
static constexpr double kOne = 1.0; // 1 | |
/** | |
* @brief ロジスティック回帰(説明変数2個)の計算 | |
* | |
* @param[ref] パラメータ(定数・係数) ps (vector<double>) | |
* @return 真偽(bool) | |
* @retval true 成功 | |
* @retval false 失敗 | |
*/ | |
bool Calc::reg_logistic(std::vector<double>& ps) { | |
std::size_t e; // 元の数 | |
std::size_t n; // サンプル数 | |
unsigned int i; // loop インデックス | |
unsigned int j; // loop インデックス | |
double loss; // 交差エントロピー誤差 | |
double loss_pre; // 交差エントロピー誤差(退避用) | |
try { | |
// 元の数, サンプル数 | |
e = data[0].size() - 1; | |
n = data.size(); | |
// データの行列化 | |
Eigen::MatrixXd mtx(n, 3); // n x 3 行列(double) b 用 | |
for (i = 0; i < n; i++) | |
for (j = 0; j < 3; j++) mtx(i, j) = data[i][j]; | |
// β初期値 (e + 1 次元ベクトル) | |
Eigen::VectorXd bs(e + 1); | |
bs.setConstant(kBeta); | |
// X の行列 (n 行 e + 1 列) | |
// (第1列(x_0)は定数項なので 1 固定) | |
Eigen::MatrixXd xs(n, e + 1); | |
xs.setConstant(kOne); | |
for (j = 0; j < e; j++) xs.col(j + 1) = mtx.col(j); | |
// t のベクトル (n 次元ベクトル) | |
Eigen::VectorXd ts(n); | |
ts = mtx.col(e).transpose(); | |
//# 交差エントロピー誤差初期値 | |
loss = kCel; | |
// y のベクトル (n 次元ベクトル) | |
Eigen::VectorXd ys(n); | |
// dE のベクトル (e + 1 次元ベクトル) | |
Eigen::VectorXd des(e + 1); | |
for (i = 0; i < kLoop; i++) { | |
// シグモイド関数適用(予測値計算) (n 次元ベクトル) | |
if (!sigmoid(xs * bs, ys)) { | |
std::cout << "[ERROR] Failed to calculate! (in sigmoid() function)" | |
<< std::endl; | |
return EXIT_FAILURE; | |
} | |
// dE 計算 (e + 1 次元ベクトル) | |
des = (ys - ts).transpose() * xs / n; | |
// β 更新 (e + 1 次元ベクトル) | |
bs -= kAlpha * des; | |
// 前回算出交差エントロピー誤差退避 | |
loss_pre = loss; | |
// 交差エントロピー誤差計算 | |
loss = cross_entropy_loss(ts, ys); | |
// 今回と前回の交差エントロピー誤差の差が閾値以下になったら終了 | |
if (abs(loss - loss_pre)< kEps) break; | |
} | |
// 計算結果用 vector へ格納 | |
for (i = 0; i < e + 1; i++) ps.push_back(bs(i)); | |
} catch (...) { | |
return false; // 計算失敗 | |
} | |
return true; // 計算成功 | |
} | |
/** | |
* @brief シグモイド関数 | |
* | |
* @param[ref] 計算前ベクトル as (Eigen::VectorXd) | |
* @param[ref] 計算後ベクトル ys (Eigen::VectorXd) | |
* @return 真偽(bool) | |
* @retval true 成功 | |
* @retval false 失敗 | |
*/ | |
bool Calc::sigmoid(const Eigen::VectorXd& as, Eigen::VectorXd& ys) { | |
try { | |
ys = 1.0 / (1.0 + (-as).array().exp()); | |
} catch (...) { | |
return false; // 計算失敗 | |
} | |
return true; // 計算成功 | |
} | |
/** | |
* @brief 交差エントロピー誤差関数 | |
* | |
* @param[ref] 実際の値(0 or 1)ベクトル ts (Eigen::VectorXd) | |
* @param[ref] 回帰式から得られる y 値ベクトル ys (Eigen::VectorXd) | |
* @return 真偽(bool) | |
* @retval true 成功 | |
* @retval false 失敗 | |
*/ | |
double Calc::cross_entropy_loss( | |
const Eigen::VectorXd& ts, const Eigen::VectorXd& ys) { | |
double loss; | |
Eigen::VectorXd ones; // all 1.0 のベクトル | |
try { | |
ones = Eigen::VectorXd::Ones(ys.rows()); | |
loss = (ts.array() * ys.array().log() | |
+ (ones - ts).array() * (ones - ys).array().log()).sum(); | |
} catch (...) { | |
throw; | |
} | |
return loss; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef REGRESSION_LOGISTIC_CALC_HPP_ | |
#define REGRESSION_LOGISTIC_CALC_HPP_ | |
#include <Eigen/Core> // for 行列定義・行列演算 | |
#include <vector> | |
class Calc { | |
std::vector<std::vector<double>> data; // 元データ | |
bool sigmoid(const Eigen::VectorXd&, Eigen::VectorXd&); // シグモイド関数 | |
double cross_entropy_loss( | |
const Eigen::VectorXd&, const Eigen::VectorXd&); // 交差エントロピー誤差関数 | |
public: | |
Calc(std::vector<std::vector<double>>& data) : data(data) {} | |
// コンストラクタ | |
bool reg_logistic(std::vector<double>&); // ロジスティック回帰(説明変数2個)の計算 | |
}; | |
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "file.hpp" | |
#include <iostream> | |
#include <sstream> | |
#include <string> | |
#include <vector> | |
bool File::get_text(std::vector<std::vector<double>>& data) { | |
try { | |
// ファイル OPEN | |
std::ifstream ifs(f_data); | |
if (!ifs.is_open()) return false; // 読み込み失敗 | |
// ファイル READ | |
std::string buf; // 1行分バッファ | |
while (getline(ifs, buf)) { | |
std::vector<double> rec; // 1行分ベクタ | |
std::istringstream iss(buf); // 文字列ストリーム | |
std::string field; // 1列分文字列 | |
// 1行分文字列を1行分ベクタに追加 | |
double x, y, z; | |
while (iss >> x >> y >> z) { | |
rec.push_back(x); | |
rec.push_back(y); | |
rec.push_back(z); | |
} | |
// 1行分ベクタを data ベクタに追加 | |
if (rec.size() != 0) data.push_back(rec); | |
} | |
} catch (...) { | |
std::cerr << "EXCEPTION!" << std::endl; | |
return false; | |
} | |
return true; // 読み込み成功 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef REGRESSION_LOGISTIC_FILE_HPP_ | |
#define REGRESSION_LOGISTIC_FILE_HPP_ | |
#include <fstream> | |
#include <string> | |
#include <vector> | |
class File { | |
std::string f_data; | |
public: | |
File(std::string f_data) : f_data(f_data) {} | |
bool get_text(std::vector<std::vector<double>>&); | |
}; | |
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*********************************************************** | |
ロジスティック回帰計算 | |
* 説明(独立)変数2個、目的(従属)変数1個 限定 | |
DATE AUTHOR VERSION | |
2022.11.02 mk-mode.com 1.00 新規作成 | |
Copyright(C) 2022 mk-mode.com All Rights Reserved. | |
***********************************************************/ | |
#include "calc.hpp" | |
#include "file.hpp" | |
#include <cstdlib> // for EXIT_XXXX | |
#include <iomanip> // for setprecision | |
#include <iostream> | |
#include <string> | |
#include <vector> | |
int main(int argc, char* argv[]) { | |
std::string f_data; // データファイル名 | |
std::vector<std::vector<double>> data; // データ配列 | |
unsigned int i; // loop インデックス | |
std::vector<double> ps; // 定数・係数 beta | |
try { | |
// コマンドライン引数のチェック | |
if (argc != 2) { | |
std::cerr << "[ERROR] Number of arguments is wrong!\n" | |
<< "[USAGE] ./regression_logistic <file_name>" | |
<< std::endl; | |
return EXIT_FAILURE; | |
} | |
// ファイル名取得 | |
f_data = argv[1]; | |
// データ取得 | |
File file(f_data); | |
if (!file.get_text(data)) { | |
std::cout << "[ERROR] Failed to read the file!" << std::endl; | |
return EXIT_FAILURE; | |
} | |
// データ一覧出力 | |
std::cout << "説明変数 X 説明変数 Y 目的変数 Z" << std::endl; | |
std::cout << std::fixed << std::setprecision(4); | |
for (i = 0; i < data.size(); i++) | |
std::cout << std::setw(10) << std::right << data[i][0] | |
<< " " | |
<< std::setw(10) << std::right << data[i][1] | |
<< " " | |
<< std::setw(10) << std::right << data[i][2] | |
<< std::endl; | |
// 計算 | |
Calc calc(data); | |
if (!calc.reg_logistic(ps)) { | |
std::cout << "[ERROR] Failed to calculate!" << std::endl; | |
return EXIT_FAILURE; | |
} | |
// 結果出力 | |
std::cout << "betas =" << std::endl; | |
std::cout << std::fixed << std::setprecision(8); | |
for (auto p: ps) | |
std::cout << std::setw(16) << std::right << p << std::endl; | |
} catch (...) { | |
std::cerr << "EXCEPTION!" << std::endl; | |
return EXIT_FAILURE; | |
} | |
return EXIT_SUCCESS; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment