@pat-alt
Created May 10, 2022 11:01
Adapts a custom `torch` model trained in R (accessed from Julia through RCall.jl) for use with CounterfactualExplanations.jl.
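using Flux   # provides the sigmoid σ
using RCall  # provides the R"..." string macro and rcopy
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend

# Step 1) Wrap the R torch model in a type that subtypes AbstractFittedModel
struct TorchNetwork <: Models.AbstractFittedModel
    nn::Any  # handle to the torch model living in the R session
end

# Step 2) Extend logits and probs for the new type
function logits(M::TorchNetwork, X::AbstractArray)
    nn = M.nn
    # Julia stores observations as columns (features × samples), while the
    # torch model expects observations as rows, hence t() before the forward pass
    y = rcopy(R"as_array($nn(torch_tensor(t($X))))")
    # a single observation may come back as a scalar, so wrap it in a vector
    y = isa(y, AbstractArray) ? y : [y]
    return y'  # 1 × n row: one logit per observation
end

function probs(M::TorchNetwork, X::AbstractArray)
    # binary classifier: probabilities are the sigmoid of the logits
    return σ.(logits(M, X))
end

# assumes library(torch) is loaded in the R session and the trained model is bound to `model`
M = TorchNetwork(R"model")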
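The adapter above relies on a shape convention: `logits` receives a features × samples array, transposes it for the model, and returns a 1 × n row of scores, with `probs` applying the sigmoid on top. The sketch below checks that convention in plain Julia without R or torch. `MockNetwork`, `forward` and `sigmoid` are hypothetical stand-ins (for the R torch model, the R forward pass and Flux's `σ`), and the local `AbstractFittedModel` is a hypothetical stand-in for `Models.AbstractFittedModel`, so it illustrates shapes only and is not the package's API.

```julia
# Dependency-free sketch of the logits/probs shape convention used above.
# All names here are hypothetical stand-ins, not CounterfactualExplanations.jl API.

abstract type AbstractFittedModel end  # stand-in for Models.AbstractFittedModel

# Stand-in for the R torch model: a single linear layer.
struct MockNetwork <: AbstractFittedModel
    w::Vector{Float64}  # one weight per feature
    b::Float64          # bias
end

# Stand-in for the R forward pass: takes samples × features and returns one
# logit per sample, collapsing to a scalar when there is a single sample.
function forward(M::MockNetwork, Xt::AbstractMatrix)
    z = Xt * M.w .+ M.b
    return length(z) == 1 ? z[1] : z
end

sigmoid(x) = 1 / (1 + exp(-x))

function logits(M::MockNetwork, X::AbstractArray)
    Xt = permutedims(X)              # features × samples -> samples × features (like t() in R)
    y = forward(M, Xt)
    y = isa(y, AbstractArray) ? y : [y]  # wrap a scalar result, as in the adapter
    return y'                        # 1 × n row: one logit per sample
end

probs(M::MockNetwork, X::AbstractArray) = sigmoid.(logits(M, X))
```

For example, with `M = MockNetwork([1.0, -1.0], 0.0)` and a 2 × 3 input, `logits(M, X)` is a 1 × 3 row, and a single observation passed as a vector yields a 1 × 1 result, which is the case the scalar-wrapping line guards against.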