Skip to content

Instantly share code, notes, and snippets.

@maxdrohde
Created January 17, 2021 07:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxdrohde/f807b2fbeb7b66532ec8a5e7f3c6e2bf to your computer and use it in GitHub Desktop.
Save maxdrohde/f807b2fbeb7b66532ec8a5e7f3c6e2bf to your computer and use it in GitHub Desktop.
library(tidyverse)
library(glue)
library(gganimate)
# Reference curve ----
# Sample y = (x + 2)^2 + 3 at 10,000 evenly spaced points over [-10, 5] so the
# true function can be plotted as a smooth line.
grid <- seq(from = -10, to = 5, length.out = 1e4)
y <- 3 + (grid + 2)^2
plot_data <- tibble(grid = grid, y = y)
# Gradient descent ----
# Objective f(x) = (x + 2)^2 + 3, minimised at x = -2; df_dx is its analytic
# derivative.
f <- function(x) {
  (x + 2)^2 + 3
}
df_dx <- function(x) {
  2 * x + 4
}

learning_rate <- 0.1
tolerance <- 0.001 # stop once a step is smaller than this in magnitude
max_iter <- 5000L # hard cap on the number of recorded steps
x <- 5 # starting point

# Trace of every step: the position before the step, f at that position, and
# the step taken. Preallocated at the cap (instead of growing inside the loop)
# and trimmed to the steps actually taken after the loop.
xs <- numeric(max_iter)
ys <- numeric(max_iter)
deltas <- numeric(max_iter)

iter <- 1
repeat {
  xs[[iter]] <- x
  ys[[iter]] <- f(x)
  current_grad <- df_dx(x)
  delta <- -current_grad * learning_rate
  deltas[[iter]] <- delta
  x <- x + delta
  # %s formats numbers via as.character(), matching the previous glue output.
  cat(sprintf(
    "Iteration: %s\nx: %s\ny: %s\ndelta: %s\n",
    iter, x, f(x), delta
  ))
  iter <- iter + 1
  # If the learning rate is too large, the iterates overflow to Inf/NaN. The
  # convergence test below would then evaluate to NA, and `if (NA)` errors.
  # Stop with a clear warning instead.
  if (!is.finite(delta)) {
    warning(
      "Gradient descent diverged at iteration ", iter - 1,
      "; try a smaller learning_rate.",
      call. = FALSE
    )
    break
  }
  # Scalar condition: use short-circuit `||`, not elementwise `|`.
  if (abs(delta) < tolerance || iter > max_iter) {
    break
  }
}

# Exactly `iter - 1` steps were recorded; drop the unused preallocated tail.
n_steps <- iter - 1
xs <- xs[seq_len(n_steps)]
ys <- ys[seq_len(n_steps)]
deltas <- deltas[seq_len(n_steps)]
# Collect the descent trace into one tibble. `iter` is the 1-based step index;
# seq_along() (unlike 1:length()) yields an empty index for an empty trace.
df <- tibble(x = xs, y = ys, delta = deltas, iter = seq_along(xs))
print(nrow(df))
# Animated plot of the descent ----
# ggplot draws layers in the order they are added. The reference curve is
# therefore added first, so the red descent points (and their shadow marks) are
# drawn on top of it rather than being crossed by the black line.
# transition_states() uses one state per step (`iter`). shadow_mark() leaves
# faded black marks at earlier states.
# The subtitle is a gganimate glue template: {closest_state} is the name of the
# current state (an `iter` value as a string), used to index back into df.
a <- ggplot(df, aes(x = x, y = y)) +
  geom_line(data = plot_data, aes(x = grid, y = y)) +
  geom_point(size = 2, color = "#8f2726") +
  cowplot::theme_cowplot(font_size = 14, font_family = "Lato") +
  labs(
    title = glue("Learning rate: {learning_rate}"),
    subtitle = paste0(
      "Iteration: {closest_state} \n",
      " x: {df$x[as.integer(closest_state)]} \n",
      " y: {df$y[as.integer(closest_state)]} \n",
      " Delta: {df$delta[as.integer(closest_state)]}"
    )
  ) +
  transition_states(iter, wrap = FALSE) +
  shadow_mark(alpha = 0.2, color = "black", size = 1) +
  ease_aes("cubic-in-out")
# Render and save ----
# Renders two output frames per recorded step at 20 fps. The output is
# 5 x 5 inches at res = 300 and is encoded to MP4 via ffmpeg.
a_rendered <- animate(
  a,
  fps = 20,
  nframes = nrow(df) * 2,
  detail = 10,
  res = 300,
  height = 5,
  width = 5,
  # Spell out the device argument `units`. The previous `unit =` relied on
  # partial argument matching, which is fragile when arguments are forwarded
  # through `...`.
  units = "in",
  renderer = ffmpeg_renderer()
)
anim_save(filename = "quadratic_GD1.mp4", animation = a_rendered)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment