Skip to content

Instantly share code, notes, and snippets.

@bschneidr
Created August 19, 2019 02:19
Show Gist options
  • Save bschneidr/61eb2ed5bf0684ba475b987019b7a110 to your computer and use it in GitHub Desktop.
Save bschneidr/61eb2ed5bf0684ba475b987019b7a110 to your computer and use it in GitHub Desktop.
Function to get the Hessian of an R expression in a variety of output formats
#' Compute the Hessian of an expression, symbolically or numerically
#'
#' Differentiates `f` twice with `D()` with respect to every variable returned
#' by `all.vars(f)`, in that order. Because variables come from `all.vars()`,
#' every symbol in `f` (including constants such as `pi`) is treated as a
#' variable of the Hessian.
#'
#' @param f An expression or call to differentiate (anything `D()` accepts).
#' @param as_matrix If `TRUE` and `eval_at` is `NULL`, return the symbolic
#'   Hessian as a character matrix of deparsed derivatives.
#' @param eval_at `NULL`, or a named list of variable values at which to
#'   evaluate the Hessian numerically. Names not supplied here are looked up
#'   in the caller's environment.
#' @return One of:
#'   * (default) a named nested list where `result[[v_i]][[v_j]]` is the
#'     unevaluated second partial derivative `D(D(f, v_i), v_j)`;
#'   * (`as_matrix = TRUE`) a character matrix of those derivatives;
#'   * (`eval_at` supplied; `as_matrix` is then ignored) a numeric matrix.
#'   Matrix rows/columns follow the order of `all.vars(f)`; no dimnames are set.
get_hessian <- function(f, as_matrix = FALSE, eval_at = NULL) {
  # Captured first so that numeric evaluation falls back to the caller's
  # variables, never to this function's own locals (f, result, i, ...).
  caller_env <- parent.frame()

  fn_inputs <- all.vars(f)
  names(fn_inputs) <- fn_inputs
  n_inputs <- length(fn_inputs)

  # Symbolic Hessian as a named nested list: row i holds d/d(v_j) of d f/d(v_i).
  result <- lapply(fn_inputs, function(var_i) {
    first_deriv <- D(f, var_i)
    lapply(fn_inputs, function(var_j) D(first_deriv, var_j))
  })

  if (is.null(eval_at)) {
    if (!as_matrix) {
      return(result)
    }
    # deparse() can split a long derivative over several strings; collapse
    # them so each derivative fits in exactly one cell.
    deparse_cell <- function(expr) {
      paste(deparse(expr, width.cutoff = 500L), collapse = " ")
    }
    matrix_result <- matrix(character(n_inputs * n_inputs),
                            nrow = n_inputs, ncol = n_inputs)
    for (i in seq_len(n_inputs)) {
      for (j in seq_len(n_inputs)) {
        matrix_result[i, j] <- deparse_cell(result[[i]][[j]])
      }
    }
    return(matrix_result)
  }

  # Numeric Hessian at the point given by eval_at.
  result_vals <- matrix(numeric(n_inputs * n_inputs),
                        nrow = n_inputs, ncol = n_inputs)
  for (i in seq_len(n_inputs)) {
    for (j in seq_len(n_inputs)) {
      result_vals[i, j] <- eval(result[[i]][[j]], envir = eval_at,
                                enclos = caller_env)
    }
  }
  result_vals
}
# Example usage ------------------------------------------------------------
# The function f(x, y) = x^2 * y^2, written as an R expression.
my_fn <- expression((x^2) * (y^2))

# Symbolic Hessian, returned as a character matrix.
get_hessian(f = my_fn, as_matrix = TRUE)

# Symbolic Hessian, returned as a nested list of unevaluated derivatives.
get_hessian(f = my_fn, as_matrix = FALSE)

# Numeric Hessian, evaluated at the point (x, y) = (2, 2).
get_hessian(f = my_fn, eval_at = list(x = 2, y = 2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment