@pat-alt
Created May 10, 2022 11:03
Adapts the gradient of the counterfactual loss function so that CounterfactualExplanations.jl can be used with a model trained in R.
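using RCall  # provides the R"..." string macro and rcopy
using LinearAlgebra
using CounterfactualExplanations, CounterfactualExplanations.Generators
import CounterfactualExplanations.Generators: ∂ℓ

# Counterfactual loss: gradient of binary cross-entropy (with logits) with
# respect to the counterfactual x′, computed by R torch autograd.
# Assumes `M.nn` is an R torch nn_module and that library(torch) is loaded in the R session.
function ∂ℓ(
    generator::AbstractGradientBasedGenerator,
    counterfactual_state::CounterfactualState,
)
    M = counterfactual_state.M
    nn = M.nn                                # R torch model
    x′ = counterfactual_state.x′             # current counterfactual (Julia array)
    t = counterfactual_state.target_encoded  # encoded target label

    # Forward pass and backpropagation in R, tracking gradients w.r.t. the input x.
    R"""
    x <- torch_tensor($x′, requires_grad=TRUE)
    output <- $nn(x)
    loss_fun <- nnf_binary_cross_entropy_with_logits
    obj_loss <- loss_fun(output, $t)
    obj_loss$backward()
    """

    # Copy the input gradient from R back into a Julia array.
    grad = rcopy(R"as_array(x$grad)")
    return grad
end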
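To see what `∂ℓ` should return, it can help to compare against a case with a closed form. The sketch below assumes a hypothetical model that is a single linear layer (logistic regression) with weights `w` and bias `b`. In that case the logit is z = wᵀx + b and the gradient of binary cross-entropy with logits with respect to x is (σ(z) − t)·w. The names `bce_with_logits`, `grad_bce` and `fd_grad` are illustrative only and are not part of CounterfactualExplanations.jl or R torch. The central finite-difference check is a generic way to test any gradient.

```julia
using LinearAlgebra: dot

# Logistic sigmoid.
σ(z) = 1 / (1 + exp(-z))

# Binary cross-entropy with logits for one logit z and target t ∈ {0, 1},
# in the numerically stable form max(z, 0) - z*t + log(1 + exp(-|z|)).
bce_with_logits(z, t) = max(z, 0) - z * t + log1p(exp(-abs(z)))

# Closed-form gradient w.r.t. x for the hypothetical linear model z = w'x + b:
# d(loss)/dz = σ(z) - t and dz/dx = w.
function grad_bce(w::AbstractVector, b::Real, x::AbstractVector, t::Real)
    z = dot(w, x) + b
    return (σ(z) - t) .* w
end

# Central finite differences, used only to check a gradient numerically.
function fd_grad(f, x::AbstractVector; h = 1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Usage with made-up numbers: both gradients should agree to roughly 1e-6.
w, b = [0.5, -1.2, 2.0], 0.1
x, t = [1.0, 0.3, -0.7], 1.0
loss(x) = bce_with_logits(dot(w, x) + b, t)
println(grad_bce(w, b, x, t))
println(fd_grad(loss, x))
```

The same finite-difference comparison could, in principle, be pointed at the R-backed `∂ℓ` by evaluating the R loss at perturbed inputs. Note that this costs two R round trips per input coordinate.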