#' Stochastic gradient descent
#'
#' @param X Feature matrix.
#' @param y Vector containing training labels.
#' @param eta Learning rate.
#' @param n_iter Maximum number of iterations.
#' @param w_init Initial parameter values.
#' @param save_steps Boolean indicating whether coefficients should be saved at each step.
#'
#' @return
#' REVISE algorithm - a simplified version
#'
#' @param classifier The fitted classifier.
#' @param x_star Attributes of the individual seeking recourse.
#' @param eta Learning rate.
#' @param lambda Regularization parameter.
#' @param n_iter Maximum number of iterations.
#' @param save_steps Boolean indicating whether intermediate steps should be saved.
#'
#' @return
logit <- function(X, y, beta_0=NULL, tau=1e-9, max_iter=10000) {
  if(!all(X[,1]==1)) {
    X <- cbind(1,X)
  }
  p <- ncol(X)
  n <- nrow(X)
  # Initialization: ----
  if (is.null(beta_0)) {
    beta_latest <- matrix(rep(0, p)) # naive first guess
  }
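# Loss: negative log-likelihood plus a quadratic penalty around w_0.
# Note: `sigmoid` and `∑` are not defined in this excerpt; they are assumed
# to be defined elsewhere in the original source.
function 𝓁(w,w_0,H_0,X,y)
    N = length(y)
    D = size(X)[2]
    μ = sigmoid(w,X)
    Δw = w-w_0
    l = - ∑( y[n] * log(μ[n]) + (1-y[n]) * log(1-μ[n]) for n=1:N) + 1/2 * Δw'H_0*Δw
    return l
end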
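# Usage sketch (not part of the original file). The loss above calls `sigmoid`
# and `∑`, neither of which is defined in this excerpt; the two definitions
# below are assumed stand-ins, and the data are made up for illustration.
# H_0 is treated here as a prior precision matrix, which is also an assumption.
using LinearAlgebra

∑(itr) = sum(itr)                          # assumed alias for `sum`
sigmoid(w, X) = 1 ./ (1 .+ exp.(-X * w))   # assumed: element-wise logistic of X*w

let
    X = [1.0 0.5; 1.0 -0.5; 1.0 2.0]   # 3 observations, 2 columns (first is a constant)
    y = [1, 0, 1]
    w_0 = zeros(2)
    H_0 = Matrix(1.0I, 2, 2)
    # At w = w_0 every μ[n] equals 0.5 and the penalty term vanishes,
    # so the loss reduces to 3 * log(2).
    println(𝓁(w_0, w_0, H_0, X, y) ≈ 3 * log(2))
end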
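# Newton's Method
# Armijo sufficient-decrease condition for step size ρ along direction d_t
# (the function name is kept as spelled in the source).
function arminjo(𝓁, g_t, θ_t, d_t, args, ρ, c=1e-4)
    𝓁(θ_t .+ ρ .* d_t, args...) <= 𝓁(θ_t, args...) .+ c .* ρ .* d_t'g_t
end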
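# Usage sketch (not part of the original file): a backtracking line search
# that shrinks ρ until `arminjo` accepts the step. The helper name `backtrack`,
# the shrink factor 0.5, the step cap, and the toy objective are assumptions
# made for illustration only.
function backtrack(𝓁, g_t, θ_t, d_t, args; ρ=1.0, shrink=0.5, max_steps=50)
    for _ in 1:max_steps
        arminjo(𝓁, g_t, θ_t, d_t, args, ρ) && return ρ
        ρ *= shrink
    end
    return ρ   # fall back to the last (smallest) step tried
end

let
    f(θ) = sum(abs2, θ)     # toy objective with gradient 2θ
    θ = [1.0, -2.0]
    g = 2 .* θ
    # Steepest-descent direction, no extra arguments to the objective.
    # The full step ρ = 1 overshoots to -θ, so it is halved once.
    println(backtrack(f, g, θ, -g, ()))   # prints 0.5
end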
function newton(𝓁, θ, ∇𝓁, ∇∇𝓁, args; max_iter=100, τ=1e-5)
    # Initialize:
    converged = false # termination state
    t = 1 # iteration count
    θ_t = θ # initial parameters
# Import libraries.
using Flux, Plots, Random, PlotThemes, Statistics, BayesLaplace
theme(:wong)
# Toy data:
xs, y = toy_data_linear(100)
X = hcat(xs...); # bring into tabular format
data = zip(xs,y)
# Neural network:
# Import libraries.
using Flux, Plots, Random, PlotThemes, Statistics, BayesLaplace
theme(:wong)
# Toy data:
xs, y = toy_data_linear(100)
X = hcat(xs...); # bring into tabular format
data = zip(xs,y)
# Build MLP:
using Flux
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend
# Step 1)
struct TorchNetwork <: Models.AbstractFittedModel
    nn::Any
end
# Step 2)
import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra
# Counterfactual loss:
function ∂ℓ(
    generator::AbstractGradientBasedGenerator,
    counterfactual_state::CounterfactualState)
    M = counterfactual_state.M
    nn = M.nn
    x′ = counterfactual_state.x′
# Simple
"The `SimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset."
mutable struct SimpleInductiveClassifier{Model <: Supervised} <: ConformalSet
    model::Model
    coverage::AbstractFloat
    scores::Union{Nothing,AbstractArray}
    heuristic::Function
    train_ratio::AbstractFloat
end