Skip to content

Instantly share code, notes, and snippets.

@gbuesing
Last active August 29, 2015 14:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gbuesing/17d6528aacc556bf1de1 to your computer and use it in GitHub Desktop.
Save gbuesing/17d6528aacc556bf1de1 to your computer and use it in GitHub Desktop.
Multivariate linear regression in Ruby, adapted from an example in Andrew Ng's Machine Learning Coursera class
require 'narray' # gem install narray
# Multivariate linear regression on NArray matrices, trained either by batch
# gradient descent (#fit) or by the normal equation (#fit_normal).
#
# Inputs are standardised (per-feature mean subtracted, then divided by the
# per-feature standard deviation) before training, and the statistics captured
# during training are reused by #predict.
class LinearRegressor
  # theta        - fitted coefficients; index 0 belongs to the constant-ones
  #                column added by #add_ones_column. nil until fitted.
  # mean, std    - per-feature statistics computed from the training input.
  # cost_history - cost before training followed by the cost after each
  #                gradient-descent iteration (only populated by #fit).
  attr_reader :theta, :mean, :std, :cost_history

  # opts[:alpha]      - gradient-descent step size (default 0.01)
  # opts[:iterations] - number of gradient-descent steps (default 400)
  def initialize(opts = {})
    # Coerce to Float: with an Integer alpha, `@alpha / m` in #fit would be an
    # integer division that truncates to 0 and theta would never move.
    @alpha = (opts[:alpha] || 0.01).to_f
    @iterations = opts[:iterations] || 400
  end

  # Fits theta by batch gradient descent.
  #
  # x - array of sample rows (each row an array of feature values)
  # y - array of target values, one per sample
  #
  # Returns self so calls can be chained.
  def fit(x, y)
    x, @mean, @std = preprocess(x)
    y = NVector[*y]
    m = x.shape[1] # dimension 1 is used as the sample count throughout
    @theta = NVector.float(x.shape[0])
    @cost_history = []

    record_cost(x, y, m, 0)
    1.upto(@iterations) do |i|
      @theta -= (@alpha / m) * x.transpose * ((x * @theta) - y)
      record_cost(x, y, m, i)
    end
    self
  end

  # Predicts targets for the given sample rows using the statistics and theta
  # learned by #fit or #fit_normal.
  #
  # Raises RuntimeError when the model has not been fitted yet; without this
  # guard the input would be standardised with its own statistics and the
  # multiplication by a nil theta would fail with an unrelated error.
  def predict(x)
    raise 'LinearRegressor must be fitted before calling predict' unless @theta

    features = preprocess(x, @mean, @std).first
    features * @theta
  end

  # Fits theta in closed form via (X'X)^-1 X'y instead of iterating.
  # cost_history is not touched by this method.
  #
  # Returns self.
  def fit_normal(x, y)
    x, @mean, @std = preprocess(x)
    y = NVector[*y]
    @theta = (x.transpose * x).inverse * (x.transpose * y)
    self
  end

  private

  # Computes the current cost, appends it to cost_history and logs it.
  def record_cost(x, y, m, iteration)
    cost = compute_cost(x, y, m)
    @cost_history << cost
    log(cost, iteration)
  end

  # Cost for the current theta: errors**2 scaled by 1 / (2m).
  # NOTE(review): relies on NVector#** with exponent 2 yielding the sum of
  # squared errors (inner product) - confirm against the narray version in use.
  def compute_cost(x, y, m)
    errors = (x * @theta) - y
    (1 / (2.0 * m)) * errors**2
  end

  # Prints the cost on every tenth iteration (including iteration 0).
  def log(c, i)
    puts "[#{i}] err: #{c.to_f.round(4)}" if (i % 10).zero?
  end

  # Standardises x and prepends the constant-ones column.
  #
  # mean, std - statistics to reuse (as in #predict); when omitted they are
  #             computed from x itself along dimension 1.
  #
  # Returns [prepared_matrix, mean, std]. The order matches how #fit,
  # #fit_normal and #predict destructure and pass the values back in; the
  # previous [matrix, std, mean] order silently swapped the two.
  def preprocess(x, mean = nil, std = nil)
    x = NMatrix.cast(x)
    x_mean = mean || x.mean(1)
    x_std = std || x.stddev(1)
    x_std[x_std.eq(0)] = 1.0 # so we don't divide by 0
    # Drop to plain NArray for element-wise arithmetic, then view the result
    # as an NMatrix again.
    scaled = NMatrix.ref((NArray.ref(x) - x_mean) / x_std)
    [add_ones_column(scaled), x_mean, x_std]
  end

  # Returns a copy of m that is one entry longer along dimension 0, with
  # index 0 filled with ones and the original values shifted to 1..n.
  def add_ones_column(m)
    out = NMatrix.float(m.shape[0] + 1, m.shape[1])
    out[1..m.shape[0], true] = m
    out[0, true] = 1
    out
  end
end
require 'csv'

# Load the training set: in every CSV row the last column is the target and
# all preceding columns are features. Blank lines are skipped; previously a
# blank row made `slice` return nil and the following `map` call crashed.
x = []
y = []
CSV.foreach('ex1data2.txt', skip_blanks: true) do |row|
  *features, target = row.map(&:to_f)
  x << features
  y << target
end

# Train with the default learning rate and iteration count.
reg = LinearRegressor.new
reg.fit(x, y)

puts "Theta:"
p reg.theta

# Rows to predict, laid out like the feature columns of the training file.
samples = [
  [2104, 3],
  [1600, 3],
  [2400, 3]
]

puts "Predictions:"
p reg.predict(samples)
require 'gnuplot'

# Plot the recorded training cost per iteration and write it to
# cost_history.png.
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Cost history"
    plot.xlabel "iteration"
    plot.ylabel "error"
    plot.terminal "png"
    plot.output "cost_history.png"

    # Derive the x axis from the history itself rather than hard-coding
    # 0..400, so the plot stays correct when a non-default iteration count is
    # used. Fresh local names also avoid overwriting the outer x / y training
    # data, which the block shares with the enclosing scope.
    costs = reg.cost_history
    iterations = (0...costs.length).to_a

    plot.data << Gnuplot::DataSet.new([iterations, costs]) do |ds|
      ds.with = "lines"
      ds.notitle
    end
  end
end
2104,3,399900
1600,3,329900
2400,3,369000
1416,2,232000
3000,4,539900
1985,4,299900
1534,3,314900
1427,3,198999
1380,3,212000
1494,3,242500
1940,4,239999
2000,3,347000
1890,3,329999
4478,5,699900
1268,3,259900
2300,4,449900
1320,2,299900
1236,3,199900
2609,4,499998
3031,4,599000
1767,3,252900
1888,2,255000
1604,3,242900
1962,4,259900
3890,3,573900
1100,3,249900
1458,3,464500
2526,3,469000
2200,3,475000
2637,3,299900
1839,2,349900
1000,1,169900
2040,4,314900
3137,3,579900
1811,4,285900
1437,3,249900
1239,3,229900
2132,4,345000
4215,4,549000
2162,4,287000
1664,2,368500
2238,3,329900
2567,4,314000
1200,3,299000
852,2,179900
1852,4,299900
1203,3,239500
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment