Skip to content

Instantly share code, notes, and snippets.

@Athospd
Created April 15, 2023 22:03
Show Gist options
  • Save Athospd/97f16d5768ef72e87d2840f524597653 to your computer and use it in GitHub Desktop.
Save Athospd/97f16d5768ef72e87d2840f524597653 to your computer and use it in GitHub Desktop.
Illustration of the effect of the learning rate on a model fit
#' Build a linear model and an Adam optimizer over its parameters.
#'
#' @param lr Learning rate passed to `torch::optim_adam()`.
#' @param in_features Number of input columns the linear layer expects.
#'   Defaults to 2, matching the two-column design matrix (speed, speed^2)
#'   built by the script in this file.
#' @return A list with `lin`, an `nn_linear(in_features, 1)` module (one
#'   weight per input column plus a bias term), and `opt`, the Adam optimizer
#'   bound to that module's parameters.
model_gen <- function(lr, in_features = 2L) {
  # Intercept b0 plus one coefficient per input column (b1, b2 for 2 inputs).
  lin <- torch::nn_linear(in_features, 1L)
  opt <- torch::optim_adam(lin$parameters, lr = lr)
  list(lin = lin, opt = opt)
}
#' Run a single optimisation step on a model created by `model_gen()`.
#'
#' Zeroes the accumulated gradients, computes predictions and the loss,
#' back-propagates, and applies one optimizer update.
#'
#' @param model A list with elements `lin` (a callable module applied to the
#'   inputs) and `opt` (an optimizer exposing `zero_grad()` and `step()`), as
#'   returned by `model_gen()`.
#' @param input Predictors passed to `model$lin`. Defaults to the global `x`,
#'   preserving the original call form `model_fit(model)`.
#' @param target Response compared with the predictions. Defaults to the
#'   global `y`.
#' @param loss_fn Loss callable invoked as `loss_fn(target, prediction)`; its
#'   result must provide `backward()`. Defaults to the global `mse`.
#' @return The predictions computed before this step's parameter update.
model_fit <- function(model, input = x, target = y, loss_fn = mse) {
  model$opt$zero_grad()
  # For the script's design matrix (speed, speed^2) this is
  # b0 + b1 * speed + b2 * speed^2.
  pred <- model$lin(input)
  loss <- loss_fn(target, pred)
  loss$backward()
  model$opt$step()
  pred
}
# Design matrix for a quadratic fit of stopping distance on speed:
# column 1 is speed, column 2 is speed^2.
x <- torch_tensor(cbind(cars$speed, cars$speed^2))
# Response reshaped to an (n, 1) column so it matches the nn_linear(2, 1) output.
y <- torch_tensor(cars$dist)$unsqueeze(2)
mse <- nn_mse_loss()

# Same model, three learning rates: too large, moderate, and too small.
learning_rates <- c(2, 0.05, 0.001)
line_colors <- c("red", "royalblue", "orange")
model1 <- model_gen(learning_rates[1])
model2 <- model_gen(learning_rates[2])
model3 <- model_gen(learning_rates[3])

# Detach from the autograd graph and convert to a plain numeric vector so
# base graphics never receives a torch tensor that still requires grad.
as_plot_values <- function(tensor) {
  as.numeric(torch::as_array(tensor$detach()))
}

n_steps <- 100L
for (i in seq_len(n_steps)) {
  pred1 <- model_fit(model1)
  pred2 <- model_fit(model2)
  pred3 <- model_fit(model3)
  Sys.sleep(0.1) # pause so successive frames play back as an animation
  plot(cars, main = paste("Step", i))
  # cars is ordered by speed, so joining points in row order traces each curve.
  lines(cars$speed, as_plot_values(pred1), col = line_colors[1], lwd = 4)
  lines(cars$speed, as_plot_values(pred2), col = line_colors[2], lwd = 4)
  lines(cars$speed, as_plot_values(pred3), col = line_colors[3], lwd = 4)
  legend(
    "topleft",
    legend = paste("lr =", learning_rates),
    col = line_colors,
    lwd = 4
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment