Skip to content

Instantly share code, notes, and snippets.

@jrevels
Last active August 29, 2015 14:26
Show Gist options
  • Save jrevels/77d224ed856f36060a0e to your computer and use it in GitHub Desktop.
Save jrevels/77d224ed856f36060a0e to your computer and use it in GitHub Desktop.
using ForwardDiff
# Repeatedly evaluate `test_hessf(input)` and measure how often the result
# disagrees with the reference array `original`.
#
# A run counts as a failure when any entry differs from the matching entry of
# `original` by more than `tol` in absolute value (default 1, the threshold
# this check has always used).
#
# `max_iter` is the number of runs. It defaults to 10^7, the same count as the
# former `10e6` default, written as an integer so the loop counter is an integer.
# It must be at least 1, otherwise the ratio would be 0/0 = NaN.
#
# Prints the failure ratio (0.0 means no failures) and also returns it.
function check_failure_ratio(test_hessf::Function, input::AbstractArray,
                             original::AbstractArray, max_iter=10^7; tol=1)
    max_iter >= 1 || throw(ArgumentError("max_iter must be at least 1, got $max_iter"))
    bad_versions = 0
    for _ in 1:max_iter
        # maximum(f, A) applies abs per element without building an abs(A) copy
        if maximum(abs, test_hessf(input) - original) > tol
            bad_versions += 1
        end
    end
    bad_ratio = bad_versions / max_iter
    println("Ratio of failures (0.0 means no failures): ", bad_ratio)
    return bad_ratio
end
#################
# First example #
#################
# First test function: the arctangent of the single entry of `y`.
# `y` must hold exactly one element (checked by the assertion).
# The element type `T` is left generic because the function is passed to
# `hessian_func` right after this definition, which presumably evaluates it
# on ForwardDiff number types.
function my_fun2{T<:Number}(y::Array{T})
    @assert length(y) == 1
    only_entry = y[1]
    return atan(only_entry)
end
# First example: Hessian of the scalar atan function at y = 4.
# The first evaluation serves as the reference, and the check counts how often
# later evaluations deviate from it.
my_hess2 = hessian_func(my_fun2, Partials{1,Float64}; mutates=false)
y = Float64[4]
original = my_hess2(y)
check_failure_ratio(my_hess2, y, original)
##################
# Second example #
##################
# Module-level data read inside `my_fun`. Declared `const` so their types are
# known where `my_fun` reads them. Untyped non-const globals force a dynamic
# lookup on every access, and `my_fun` is evaluated many times during the
# failure-ratio check. The values are unchanged.
const indices = Int64[3, 2, 1]
const coeffs = [10., 100., 1000.]
# Second test function: maps a 3-element input `y` to a 3-element output.
# Each term is built from the module-level `coeffs` and `indices`. With
# `indices = [3, 2, 1]`, the three terms read y[1], y[2] and y[1] respectively.
# The result is assembled as `T[...]`, i.e. with the element type of `y`.
#
# NOTE(review): the third term reuses `coeffs[2]` and `indices[3]`, so
# `coeffs[3]` is never read. The pattern of the first two terms suggests that
# `coeffs[3] * atan(y[indices[1]])` may have been intended. Confirm before
# changing it, because the third output component is the one the second
# example checks.
function my_fun{T<:Number}(y::Array{T})
    @assert length(y) == 3
    linear_term = coeffs[1] * y[indices[3]]
    square_term = coeffs[2] * y[indices[2]] ^ 2
    atan_term   = coeffs[2] * atan(y[indices[3]])
    return T[linear_term, square_term, atan_term]
end
# Build one Hessian function per output component of the vector-valued
# `my_fun_arg`. Entry k is `hessian_func` applied to `x -> my_fun_arg(x)[k]`,
# using the partials type `P` and `mutates=false`.
#
# `nout` is the number of output components to wrap. It defaults to `K`, the
# first type parameter of `P`, which is the count this function always used.
# In this file's calls `K` happens to equal both the input and output length.
# Pass `nout` explicitly when the output length differs from `K`.
function get_hess_func_vec{K,T}(my_fun_arg::Function, P::Type{Partials{K,T}},
                                nout::Integer=K)
    return [hessian_func(x -> my_fun_arg(x)[k], P, mutates=false) for k = 1:nout]
end
# Second example: build the Hessian functions for all three output components
# of `my_fun` and keep only the third. Its first evaluation at y = [1, 2, 3]
# serves as the reference, and the check counts how often later evaluations
# deviate from it.
my_hess = get_hess_func_vec(my_fun, Partials{3,Float64})[3];
y = Float64[1, 2, 3]
original = my_hess(y);
check_failure_ratio(my_hess, y, original)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment