Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
library(torch)
#' Minimise a scalar torch objective with a torch optimizer
#'
#' Runs a fixed number of optimizer steps on `params`, returning a list whose
#' shape loosely mirrors `stats::optim()`.
#'
#' @param params Starting point, e.g. a tensor created with
#'   `requires_grad = TRUE` as in the example call. The optimizer updates it
#'   in place, and the same object is returned as `par`.
#' @param fn Objective: a function of `params` returning a single-element
#'   tensor (it must support `$backward()` and `as.numeric()`).
#' @param method Optimizer name; `paste0("optim_", method)` must name an
#'   optimizer constructor in scope (e.g. `"adam"` -> `optim_adam()`).
#' @param iterations Number of optimizer steps to take (>= 0).
#' @param ... Further arguments forwarded to the optimizer constructor
#'   (e.g. `lr`).
#' @param trace_every Emit the current objective value with `message()` every
#'   `trace_every` iterations; `0` disables tracing. Must be named exactly.
#' @return A list with `par` (the optimized `params`), `value` (the objective
#'   evaluated at `par`, computed without gradient tracking) and
#'   `convergence` (always `1`: iteration limit reached, as in
#'   `stats::optim()`).
optim_torch <- function(params, fn, method, iterations = 1000, ...,
                        trace_every = 100) {
  stopifnot(
    is.character(method), length(method) == 1L,
    is.numeric(iterations), length(iterations) == 1L, iterations >= 0,
    is.numeric(trace_every), length(trace_every) == 1L
  )

  # Resolve the constructor up front so a typo in `method` gives a readable
  # error instead of an opaque do.call() lookup failure.
  optimizer_name <- paste0("optim_", method)
  make_optimizer <- get0(optimizer_name, envir = environment(),
                         mode = "function")
  if (is.null(make_optimizer)) {
    stop("Unknown optimizer method '", method, "': no function `",
         optimizer_name, "()` found.", call. = FALSE)
  }
  optimizer <- make_optimizer(params, ...)

  for (i in seq_len(iterations)) {
    obj_val <- fn(params)
    if (trace_every > 0 && i %% trace_every == 0) {
      message(as.numeric(obj_val))
    }
    optimizer$zero_grad()
    obj_val$backward()
    optimizer$step()
  }

  # Re-evaluate after the final step: the loop's last `obj_val` predates the
  # last update, so it would not describe the returned `par`. This also makes
  # `value` defined when `iterations` is 0.
  value <- with_no_grad(fn(params))

  list(
    par = params,
    value = value,
    convergence = 1 # iteration limit reached (same code as stats::optim)
  )
}

## Rosenbrock "banana" function, the example objective from the R optim() docs.
## Its minimum value 0 is attained at x = c(1, 1). Only indexing and
## arithmetic are used, so it accepts plain numeric vectors as well as the
## torch tensor passed in the example call below.
fr <- function(x) {
  a <- x[1]
  b <- x[2]
  (1 - a)^2 + 100 * (b - a * a)^2
}
# Minimise fr() with Adam from the optim() docs' starting point c(-1.2, 1).
optim_torch(
  torch_tensor(c(-1.2, 1), requires_grad = TRUE),
  fr,
  method = "adam",
  iterations = 1000,
  lr = 0.1 # not an optim_torch() argument: forwarded via `...` to optim_adam()
)
#> 3.00256299972534
#> 1.01531004905701
#> 0.194786816835403
#> 0.0531327575445175
#> 0.0157341733574867
#> 0.00455428753048182
#> 0.00123021111357957
#> 0.000301693304209039
#> 6.5903142967727e-05
#> 1.26426457427442e-05
#> $par
#> torch_tensor
#>  0.9965
#>  0.9930
#> [ CPUFloatType{2} ]
#> 
#> $value
#> torch_tensor
#> 1e-05 *
#>  1.2643
#> [ CPUFloatType{1} ]
#> 
#> $convergence
#> [1] 1

Created on 2020-10-24 by the reprex package (v0.3.0)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.