Skip to content

Instantly share code, notes, and snippets.

@dfalbel
Created January 28, 2021 20:42
Show Gist options
  • Save dfalbel/10f2fca89dd1e7713be62785435e9064 to your computer and use it in GitHub Desktop.
torch for R examples
library(torch)
library(ggplot2)
# we want to find the minimum of this function
# using the gradient descent.
# Objective function whose minimum we want to find by gradient descent.
f <- function(x) {
  x^2 - x
}
# Plot f over [-2, 2] so the descent path can be overlaid later.
# qplot() is deprecated in ggplot2 (>= 3.4); use ggplot() directly.
p <- ggplot() +
  stat_function(fun = f) +
  xlim(-2, 2)
p
# for the gradient descent to work we need to know how
# to compute derivatives, for example:
# f(x) = x^2 - x
# Analytic derivative of f: d/dx (x^2 - x) = 2x - 1.
df_dx <- function(x) {
  2 * x - 1
}
# Gradient descent with the hand-written derivative:
#   x[t+1] = x[t] - learning_rate * df/dx(x[t])
x <- -2
learning_rate <- 0.1
n_steps <- 15

# Preallocate the step endpoints instead of growing a tibble row by row
# (O(n^2) copies), and avoid `%>%`, which is not attached here — only
# torch and ggplot2 are loaded, neither of which exports the pipe.
xstart <- numeric(n_steps)
xend <- numeric(n_steps)
for (t in seq_len(n_steps)) {
  xstart[t] <- x
  x <- x - learning_rate * df_dx(x)
  xend[t] <- x
}
segments <- tibble::tibble(x = xstart, xend = xend)
# Overlay the descent steps on the curve as red arrows, one per iteration.
p +
  geom_segment(
    mapping = aes(x = x, y = f(x), xend = xend, yend = f(xend)),
    data = segments,
    colour = "red",
    arrow = arrow(length = unit(0.2, "cm"))
  )
library(torch)
library(ggplot2)
# we want to find the minimum of this function
# using the gradient descent.
# Objective function whose minimum we want to find by gradient descent.
f <- function(x) {
  x^2 - x
}
# Plot f over [-2, 2] so the descent path can be overlaid later.
# qplot() is deprecated in ggplot2 (>= 3.4); use ggplot() directly.
p <- ggplot() +
  stat_function(fun = f) +
  xlim(-2, 2)
p
# for the gradient descent to work we need to know how
# to compute derivatives, for example:
# f(x) = x^2 - x
# Analytic derivative of f: d/dx (x^2 - x) = 2x - 1.
df_dx <- function(x) {
  2 * x - 1
}
# In torch the derivative is computed by autograd instead of by hand:
x <- torch_tensor(-2, requires_grad = TRUE)
y <- f(x)
y$backward()
# Reading x$grad gives us df/dx at x = -2 (here: 2*(-2) - 1 = -5).
# NOTE: the original called x$grad$zero_() here, which *zeroes* the
# gradient rather than reporting it — the opposite of the comment's claim.
x$grad
# Gradient descent using torch autograd:
#   x[t+1] = x[t] - learning_rate * df/dx(x[t])
learning_rate <- 0.1
n_steps <- 15

# Preallocate the step endpoints instead of growing a tibble row by row,
# and avoid `%>%`, which is not attached here — only torch and ggplot2
# are loaded, neither of which exports the pipe.
xstart <- numeric(n_steps)
xend <- numeric(n_steps)
for (t in seq_len(n_steps)) {
  xstart[t] <- as.numeric(x)
  x$grad$zero_()  # clear the gradient accumulated by the previous backward()
  y <- f(x)
  y$backward()    # populates x$grad with df/dx at the current x
  # Update x in place without recording the step in the autograd graph.
  with_no_grad({
    x$sub_(learning_rate * x$grad)
  })
  xend[t] <- as.numeric(x)
}
segments <- tibble::tibble(x = xstart, xend = xend)
# Overlay the autograd descent steps on the curve as red arrows.
p +
  geom_segment(
    mapping = aes(x = x, y = f(x), xend = xend, yend = f(xend)),
    data = segments,
    colour = "red",
    arrow = arrow(length = unit(0.2, "cm"))
  )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment