
@komasaru
Created October 11, 2022 02:33
Ruby script to calculate a logistic regression.
#! /usr/local/bin/ruby
#*********************************************
# Ruby script to compute a logistic regression analysis.
# (by extending the matrix class)
#*********************************************
#
require 'matrix'

class Matrix
  ALPHA = 0.01        # Learning rate
  EPS   = 1.0e-12     # Convergence threshold
  LOOP  = 10_000_000  # Maximum number of iterations
  BETA  = 5.0         # Initial value for each beta
  CEL   = 0.0         # Initial cross-entropy loss
  def reg_logistic
    # Raise an exception if the receiver Matrix is empty
    raise "Self matrix is empty!" if self.empty?
    # Number of explanatory variables, number of samples
    e = self.column_size - 1
    n = self.row_size
    # Initial betas (1 row, e + 1 columns)
    bs = Matrix.build(1, e + 1) { |_| BETA }
    # X matrix (n rows, e + 1 columns)
    # (the first column (x_0) is the constant term, fixed at 1)
    xs = Matrix.hstack(Matrix.build(n, 1) { 1 }, self.minor(0, n, 0, e))
    # t matrix (n rows, 1 column)
    ts = self.minor(0, n, e, 1)
    # Initial cross-entropy loss
    loss = CEL
    LOOP.times do |i|
      #puts "i=#{i}"
      # Apply the sigmoid function (compute predictions)
      ys = sigmoid(xs * bs.transpose)
      # Compute the gradient dE
      des = (ys - ts).transpose * xs / n
      # Update betas
      bs -= ALPHA * des
      # Save the cross-entropy loss from the previous iteration
      loss_pre = loss
      # Compute the cross-entropy loss
      loss = cross_entropy_loss(ts, ys)
      # Stop when the change in loss falls below the threshold
      break if (loss - loss_pre).abs < EPS
    end
    return bs
  end
  private

  # Sigmoid function
  def sigmoid(x)
    x.map { |a| 1.0 / (1.0 + Math.exp(-a)) }
  end
  # Cross-entropy loss function (negative log-likelihood)
  def cross_entropy_loss(ts, ys)
    -ts.zip(ys).map { |t, y|
      t * Math.log(y) + (1.0 - t) * Math.log(1.0 - y)
    }.sum
  end
end
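The update rule inside the loop is plain batch gradient descent: the gradient of the cross-entropy loss with respect to the betas is `(y - t)^T X / n`, and each iteration subtracts `ALPHA` times that gradient. A minimal, self-contained sketch of a single step on a hypothetical toy dataset (not the gist's data) shows that one update already lowers the loss:

```ruby
require 'matrix'

alpha = 0.1
# Hypothetical toy data: column 0 of xs is the constant term (x_0 = 1)
xs = Matrix[[1.0, 2.0], [1.0, -1.0], [1.0, 0.5]]
ts = Matrix[[0.0], [1.0], [0.0]]
bs = Matrix[[0.0, 0.0]]  # start from zero betas for illustration

sigmoid = ->(m) { m.map { |a| 1.0 / (1.0 + Math.exp(-a)) } }
loss    = ->(t, y) {
  -t.zip(y).map { |ti, yi| ti * Math.log(yi) + (1.0 - ti) * Math.log(1.0 - yi) }.sum
}

ys0    = sigmoid.call(xs * bs.transpose)       # predictions before the step
before = loss.call(ts, ys0)
bs    -= alpha * ((ys0 - ts).transpose * xs) / 3  # one gradient-descent step
after  = loss.call(ts, sigmoid.call(xs * bs.transpose))
```

With all betas at zero, every prediction starts at 0.5; after one step the loss drops, which is the direction the gist's loop keeps following until the change falls below `EPS`.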
# Explanatory (independent) variables and objective (dependent) variable
# ( e.g. n rows, 3 columns (x1, x2, y) )
data = Matrix[
[30, 21, 0],
[22, 10, 0],
[26, 25, 0],
[14, 20, 0],
[ 6, 10, 1],
[ 2, 15, 1],
[ 6, 5, 1],
[10, 5, 1],
[19, 15, 1]
]
puts "data ="
data.to_a.each { |row| p row }
# Compute the constant and coefficients of the logistic regression equation (b0, b1, b2, ...)
puts "\nNow computing...\n\n"
reg_logistic = data.reg_logistic
puts "betas = "
p reg_logistic.to_a
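Once fitted, the betas turn into predictions via the same sigmoid: the probability that y = 1 for a sample (x1, x2) is sigmoid(b0 + b1·x1 + b2·x2). A hedged sketch, where the beta values are illustrative placeholders and not the coefficients this script actually produces:

```ruby
# Predicted probability of y = 1 for one sample, given betas [b0, b1, b2]
def predict(betas, x1, x2)
  z = betas[0] + betas[1] * x1 + betas[2] * x2
  1.0 / (1.0 + Math.exp(-z))
end

betas = [10.0, -0.6, -0.1]   # illustrative values only, not the fitted result
p1 = predict(betas, 30, 21)  # a sample labeled 0 in the data above
p2 = predict(betas, 6, 10)   # a sample labeled 1 in the data above
```

Classifying with a 0.5 threshold on these probabilities reproduces the 0/1 labels when the betas separate the two groups well.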