Skip to content

Instantly share code, notes, and snippets.

@vsdsantos
Last active September 14, 2020 22:41
Show Gist options
  • Save vsdsantos/0d5c1f31b2adaa4540841ddaf5b2ec55 to your computer and use it in GitHub Desktop.
Save vsdsantos/0d5c1f31b2adaa4540841ddaf5b2ec55 to your computer and use it in GitHub Desktop.
"""
    gradient_descent(f, ∇f, x0, ϵ)

Run steepest descent on `f` from `x0` and return the iteration counter.

At each iterate `x_n` the search direction is the unit gradient
`u_n = ∇f(x_n) / ‖∇f(x_n)‖`. The step length `t` along `-u_n` comes from
`quadratic_interp(f, x_n, u_n)`, giving `x_{n+1} = x_n - t*u_n`.

The loop stops when
- the gradient is exactly zero (a stationary point), or
- the relative change `|f(x_{n+1}) - f(x_n)| / |f(x_n)|` falls below `ϵ`.

# Arguments
- `f`: objective function, called as `f(x)`.
- `∇f`: gradient of `f`, called as `∇f(x)` and returning a vector.
- `x0`: starting point.
- `ϵ`: tolerance on the relative change in `f`.

# Returns
The counter `n`. It starts at 1 and is incremented once per step taken, so it
equals the number of steps plus one. The minimizer itself is not returned.

NOTE(review): there is no iteration cap. The relative-change test also divides by
`f(x_n)`, so if `f(x_n) == 0` the ratio is `Inf` or `NaN` (a `NaN` ends the loop).
"""
function gradient_descent(f, ∇f, x0, ϵ)
    e = ϵ          # seed the error so the loop body runs at least once
    n = 1
    x_n = x0
    while abs(e) >= ϵ
        f_n = f(x_n)
        grad = ∇f(x_n)
        # Euclidean norm ‖∇f(x_n)‖. The norm must come from the sum of squares,
        # not the square of the sum, or gradients whose components cancel
        # (e.g. [2, -2]) would look like zero.
        gnorm = sqrt(sum(abs2, grad))
        # A zero gradient means a stationary point was reached. Stopping here
        # also prevents a division by zero in the normalization below.
        iszero(gnorm) && break
        u_n = grad ./ gnorm                  # unit vector along the gradient
        t = quadratic_interp(f, x_n, u_n)    # step length along -u_n
        x_n1 = x_n - t * u_n                 # next iterate (downhill move)
        f_n1 = f(x_n1)
        e = abs((f_n1 - f_n) / f_n)          # relative change in f
        x_n = x_n1
        n += 1
    end
    return n
end
"""
    quadratic_interp(f, x_n, u_n; maxstep=1e10)

Approximate a line search along the ray `x_n - s*u_n` (`s ≥ 0`) by
three-point parabolic interpolation and return the chosen step length.

First a bracketing step `b` (with `c = 2b`) is found:
- if a unit step increases `f`, `b` is halved until `f(x_n - b*u_n) <= f(x_n)`;
- otherwise `b` is doubled while `f(x_n - b*u_n) >= f(x_n - c*u_n)`.

A parabola is then fitted through `(0, f(x_n))`, `(b, f(x_n - b*u_n))` and
`(c, f(x_n - c*u_n))`, and the step at its vertex is returned. If the sampled
value at `b` is already lower than the parabola's predicted minimum, `b` is
returned instead.

# Arguments
- `f`: objective, called as `f(x)` on points like `x_n`.
- `x_n`: current point.
- `u_n`: direction. Points are sampled at `x_n - s*u_n`, so passing the
  (normalized) gradient searches downhill.
- `maxstep`: abort the doubling phase once the trial step `c` exceeds this.

# Returns
The step length `t`.

# Throws
- `ErrorException` if the doubling phase exceeds `maxstep` (for example when
  `f` keeps decreasing without bound along the ray).
"""
function quadratic_interp(f, x_n, u_n; maxstep = 1e10)
    φ(s) = f(x_n - s * u_n)    # objective restricted to the search ray
    f0 = f(x_n)                # loop-invariant, so evaluate it once
    b = 1.0                    # Float64 from the start keeps `b` type-stable
    if φ(b) > f0
        # The unit step overshoots: shrink until it no longer increases f.
        while f0 < φ(b)
            b /= 2
        end
    else
        # The unit step decreases f: expand until f rises between b and 2b.
        while φ(b) >= φ(2b)
            b *= 2
            if 2b > maxstep
                error("quadratic_interp: could not bracket a minimum " *
                      "(trial step exceeded $maxstep); f may be unbounded " *
                      "below along the search direction")
            end
        end
    end
    c = 2b

    # Interpolation nodes (step, value).
    s1, s2, s3 = 0.0, b, c
    y1, y2, y3 = f0, φ(b), φ(c)

    # Coefficients of y = A s^2 + B s + C, each left multiplied by
    # denom = (s1 - s2)(s1 - s3)(s2 - s3), the common Lagrange denominator.
    # https://stackoverflow.com/questions/717762/how-to-calculate-the-vertex-of-a-parabola-given-three-points
    denom = (s1 - s2) * (s1 - s3) * (s2 - s3)
    A = s3 * (y2 - y1) + s2 * (y1 - y3) + s1 * (y3 - y2)
    B = s3^2 * (y1 - y2) + s2^2 * (y3 - y1) + s1^2 * (y2 - y3)
    C = s2 * s3 * (s2 - s3) * y1 + s3 * s1 * (s3 - s1) * y2 + s1 * s2 * (s1 - s2) * y3

    t = -B / (2A)                     # vertex abscissa (denom cancels out)
    f_t = (C - B^2 / (4A)) / denom    # vertex ordinate (needs the true denom)

    # Fall back to the sampled step if it beats the fitted parabola.
    return y2 < f_t ? b : t
end
using Plots
# Test problem: the paraboloid f(x, y) = x² + y², minimized at the origin.
f1(X) = X[1]^2 + X[2]^2
# Its analytic gradient [2x, 2y].
∇f1(X) = [2 * X[1], 2 * X[2]]

# Value returned by gradient_descent on f1 from the start point (x, y) with tolerance e.
n1(x, y, e) = gradient_descent(f1, ∇f1, [x, y], e)

# 100 × 100 grid of starting points covering [-5, 5]².
x = y = range(-5, 5; length = 100)

# Filled contour of gradient_descent's returned count over the grid, for ϵ = 0.1.
contour(
    x, y, (px, py) -> n1(px, py, 0.1);
    aspect_ratio = 1,
    fill = true,
    c = :matter,
    title = L"\epsilon = 0.1",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment