# line.rb (forked from actsasgeek/line.rb)
samples = [
  { :xs => [1.0, 0.25], :y => 0.98 },
  { :xs => [1.0, 0.49], :y => 0.82 },
  { :xs => [1.0, 0.60], :y => 0.41 },
  { :xs => [1.0, 0.89], :y => 0.31 }
]
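# note that the first input of every sample is a constant 1.0; this is the
# usual bias trick, so thetas[0] acts as the line's intercept.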
# line is the dot product of the weight vector (thetas)
# and the input vector (xs).
def line(thetas, xs)
  thetas.zip(xs).map { |t, x| t * x }.inject(:+)
end
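# a quick sanity check: with the near-optimal thetas noted at the bottom of
# this file, the prediction for the first sample's inputs is
# 1.4 * 1.0 + (-1.4) * 0.25 = 1.05.
puts line([1.4, -1.4], [1.0, 0.25]) # => 1.05, up to floating-point rounding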
# the error for a function, f, is the difference between the prediction
# made by f with a certain parameterization, thetas, applied to the
# inputs, xs, and the expected value, y.
#
# f must be passed as a symbol, e.g. :line.
def error(f, thetas, in_and_out)
  xs, y = in_and_out.values_at(:xs, :y)
  y_hat = method(f).call(thetas, xs)
  y_hat - y
end
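# for example, with those same thetas the error on the first sample is
# 1.05 - 0.98 = 0.07 (approximately, given floating-point rounding):
puts error(:line, [1.4, -1.4], samples.first) # => ~0.07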
# because we want an overall sense of error, we need to sum up
# the errors for all samples for a given parameterization; however,
# simply adding raw errors could cancel out: if the first error were
# -10 and the second were 10, the total would be zero.
#
# Therefore, we square each error. Additionally, we take the average
# squared error (because the sample size is fixed, this doesn't affect
# where the minimum lies). Finally, we take 1/2 of the value (because it
# makes the derivative nicer; since this is a constant factor, it doesn't
# affect the minimum either).
def squared_error(f, thetas, data)
  data.map { |datum| error(f, thetas, datum) ** 2 }.inject(:+)
end

def mean_squared_error(f, thetas, data)
  count = data.length
  0.5 * (1.0 / count) * squared_error(f, thetas, data)
end
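# to see how these compose: with thetas = [0.0, 0.0] every prediction is 0,
# so each squared error is just y squared, and the MSE is
# 0.5 * (1/4) * (0.98**2 + 0.82**2 + 0.41**2 + 0.31**2) = 0.237125.
puts mean_squared_error(:line, [0.0, 0.0], samples) # => ~0.2371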
# we want to generate a grid of potential parameter
# values and then print out the MSE for each
# set of values.
def plot_mse_for_thetas(step, samples)
  range = Enumerator.new do |g|
    start = -3.0
    stop = 3.0
    current = start
    while current <= stop
      g.yield current
      current += step
    end
  end
  domain = range.to_a
  puts "\t#{domain.join("\t")}"
  domain.each do |t0|
    print "#{t0}"
    domain.each do |t1|
      mse = mean_squared_error(:line, [t0, t1], samples)
      print "\t#{'%.3f' % mse}"
    end
    puts ""
  end
end

plot_mse_for_thetas(0.40, samples)
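# with step = 0.40 the domain runs -3.0, -2.6, ..., up toward 3.0
# (floating-point accumulation may or may not include the final endpoint),
# and the smallest printed MSE should sit near thetas = [1.4, -1.4], the
# minimum noted below.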
# view LaTeX here: http://www.codecogs.com/latex/eqneditor.php
# according to Andrew Ng's notes, in gradient descent, each theta should be updated
# by the following rule:
# \theta_j := \theta_j - \alpha \frac{\partial}{\partial \theta_j}MSE(\theta)
# where every $\theta_j$ should be updated simultaneously.
# the derivative of MSE with respect to \theta_j is:
# \frac{1}{m} \sum_i (f(x_i) - y_i)x_{i,j}
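# filling in the chain-rule step: since MSE(\theta) = \frac{1}{2m}\sum_i (f(x_i) - y_i)^2
# and f is linear in the thetas, \frac{\partial f(x_i)}{\partial \theta_j} = x_{i,j}, so
# \frac{\partial}{\partial \theta_j}MSE(\theta) = \frac{1}{m}\sum_i (f(x_i) - y_i)x_{i,j};
# the 1/2 cancels the 2 brought down by the power rule.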
def calculate_gradient_mse(f, thetas, samples)
  averager = 1.0 / samples.length
  gradients = []
  thetas.each_with_index do |theta, index|
    accum = 0.0
    samples.each do |sample|
      xs = sample[:xs]
      accum += error(f, thetas, sample) * xs[index]
    end
    gradients << averager * accum
  end
  gradients
end

puts calculate_gradient_mse(:line, [-3.0, -3.0], samples).inspect
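# as a sanity check (this helper is not part of the original gist), the
# analytic gradient can be compared against a centered finite-difference
# approximation of the same MSE:
def numerical_gradient_mse(f, thetas, samples, h = 1.0e-5)
  thetas.each_index.map do |j|
    plus = thetas.dup
    plus[j] += h
    minus = thetas.dup
    minus[j] -= h
    # slope of MSE along theta_j, estimated from two nearby evaluations
    (mean_squared_error(f, plus, samples) - mean_squared_error(f, minus, samples)) / (2.0 * h)
  end
end

puts numerical_gradient_mse(:line, [-3.0, -3.0], samples).inspect
# the two printed vectors should agree to several decimal places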
def gradient_descent(f, samples, thetas, alpha)
  mse = mean_squared_error(f, thetas, samples)
  while mse > 0.01
    puts "current MSE is #{mse} with thetas #{thetas.inspect}"
    gradients = calculate_gradient_mse(f, thetas, samples)
    changes = gradients.map { |g| -alpha * g }
    thetas = thetas.zip(changes).map { |a, b| a + b }
    mse = mean_squared_error(f, thetas, samples)
  end
  thetas
end
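# one caveat: the loop above only stops when MSE drops below 0.01, so a poor
# starting point or too large an alpha can spin forever. a minimal variant
# (an illustrative sketch, not part of the original) that also caps the
# number of iterations:
def gradient_descent_capped(f, samples, thetas, alpha, max_iterations = 10_000)
  max_iterations.times do
    break if mean_squared_error(f, thetas, samples) <= 0.01
    gradients = calculate_gradient_mse(f, thetas, samples)
    thetas = thetas.zip(gradients).map { |t, g| t - alpha * g }
  end
  thetas
end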
# the true global minimum is around [1.4, -1.4]
# puts gradient_descent(:line, samples, [1.4, -1.8], 0.01).inspect
# puts gradient_descent(:line, samples, [-3.0, 3.0], 0.01).inspect # doesn't find the global minimum.
# puts gradient_descent(:line, samples, [3.0, -3.0], 0.01).inspect # does find the global minimum.