library(torch)
optim_torch <- function(params, fn, method, iterations = 1000, ...,
                        report_every = 100) {
  # Minimise `fn` over `params` with a torch optimizer, returning a list
  # shaped like the output of stats::optim().
  #
  # params:       tensor(s) created with `requires_grad = TRUE`; updated in
  #               place by the optimizer.
  # fn:           objective; called as fn(params) and must return a tensor on
  #               which `$backward()` can be called.
  # method:       optimizer suffix, e.g. "adam" selects `optim_adam()`.
  # iterations:   number of optimizer steps to take.
  # ...:          forwarded to the optimizer constructor (e.g. `lr`).
  # report_every: message() the objective every this many iterations;
  #               0 or Inf disables progress output.
  #
  # Returns list(par, value, convergence):
  #   par         - `params` after the final step,
  #   value       - fn(par), evaluated without gradient tracking,
  #   convergence - always 1 (iteration limit reached), as in stats::optim().
  optimizer <- do.call(paste0("optim_", method), list(params, ...))
  for (i in seq_len(iterations)) {
    obj_val <- fn(params)
    if (report_every > 0 && i %% report_every == 0) {
      message(as.numeric(obj_val))
    }
    optimizer$zero_grad()
    obj_val$backward()
    optimizer$step()
  }
  # `obj_val` in the loop is computed *before* the last step, so it does not
  # correspond to the returned `par`. Re-evaluate at the final parameters.
  # This also makes `iterations = 0` work (the loop body never runs).
  final_val <- with_no_grad(fn(params))
  list(
    par = params,
    value = final_val,
    convergence = 1 # just iteration limit
  )
}
## Rosenbrock "banana" function, as in the examples of stats::optim().
## f(a, b) = 100 * (b - a^2)^2 + (1 - a)^2; its minimum is 0 at c(1, 1).
## Only the first two elements of `x` are used.
fr <- function(x) {
  a <- x[1]
  b <- x[2]
  valley <- b - a * a
  slope <- 1 - a
  100 * valley^2 + slope^2
}
# Minimise the banana function from the optim() docs' start point c(-1.2, 1)
# using 1000 Adam steps with learning rate 0.1.
optim_torch(
  params = torch_tensor(c(-1.2, 1), requires_grad = TRUE),
  fn = fr,
  method = "adam",
  iterations = 1000,
  lr = 0.1
)
#> 3.00256299972534
#> 1.01531004905701
#> 0.194786816835403
#> 0.0531327575445175
#> 0.0157341733574867
#> 0.00455428753048182
#> 0.00123021111357957
#> 0.000301693304209039
#> 6.5903142967727e-05
#> 1.26426457427442e-05
#> $par
#> torch_tensor
#> 0.9965
#> 0.9930
#> [ CPUFloatType{2} ]
#>
#> $value
#> torch_tensor
#> 1e-05 *
#> 1.2643
#> [ CPUFloatType{1} ]
#>
#> $convergence
#> [1] 1
# Created on 2020-10-24 by the reprex package (v0.3.0)