Simple, inductive conformal classification in Julia. Code snippet from [ConformalPrediction.jl](https://github.com/pat-alt/ConformalPrediction.jl).
# Simple

# Imports needed to run this snippet outside the package:
using MLJ                          # provides `partition`, `pdf` and `Supervised`
import MLJModelInterface as MMI
import Statistics
# Note: the abstract type `ConformalSet` is defined elsewhere in ConformalPrediction.jl.
"The `SimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset."
mutable struct SimpleInductiveClassifier{Model <: Supervised} <: ConformalSet
    model::Model                            # underlying MLJ model
    coverage::AbstractFloat                 # desired marginal coverage rate
    scores::Union{Nothing,AbstractArray}    # calibration nonconformity scores (set during fit)
    heuristic::Function                     # heuristic used to compute nonconformity scores
    train_ratio::AbstractFloat              # share of data used for training (rest for calibration)
end
function SimpleInductiveClassifier(model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=f(y, ŷ)=1.0-ŷ, train_ratio::AbstractFloat=0.5)
    return SimpleInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
end
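
# Usage sketch (illustrative, not part of the original snippet): the wrapper
# accepts any probabilistic MLJ model. `DecisionTreeClassifier` is an arbitrary
# choice here and assumes DecisionTree.jl is installed.
#
#   Tree = @load DecisionTreeClassifier pkg=DecisionTree
#   conf_model = SimpleInductiveClassifier(Tree(); coverage=0.9, train_ratio=0.8)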
@doc raw"""
MMI.fit(conf_model::SimpleInductiveClassifier, verbosity, X, y)
For the [`SimpleInductiveClassifier`](@ref) nonconformity scores are computed as follows:
``
S_i^{\text{CAL}} = s(X_i, Y_i) = h(\hat\mu(X_i), Y_i), \ i \in \mathcal{D}_{\text{calibration}}
``
A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}`` where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output of the true class and ``\hat\mu`` denotes the model fitted on training data ``\mathcal{D}_{\text{train}}``. The simple approach only takes the softmax probability of the true label into account.
"""
function MMI.fit(conf_model::SimpleInductiveClassifier, verbosity, X, y)

    # Data splitting: hold out part of the data for calibration.
    train, calibration = partition(eachindex(y), conf_model.train_ratio)
    Xtrain = MLJ.matrix(X)[train, :]
    ytrain = y[train]
    Xcal = MLJ.matrix(X)[calibration, :]
    ycal = y[calibration]

    # Training:
    fitresult, cache, report = MMI.fit(conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...)

    # Nonconformity scores: softmax probability of the true calibration label,
    # passed through the heuristic function.
    ŷ = pdf.(MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xcal)...), ycal)
    conf_model.scores = @.(conf_model.heuristic(ycal, ŷ))

    return (fitresult, cache, report)
end
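
# Worked example of the default heuristic (illustrative): if the fitted model
# assigns softmax probabilities (0.7, 0.2, 0.1) to the three classes and the
# true calibration label is class 1, the nonconformity score is
#   s = h(μ̂(x), y) = 1.0 - 0.7 = 0.3.
# Confident, correct predictions on calibration data thus yield low scores.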
@doc raw"""
MMI.predict(conf_model::SimpleInductiveClassifier, fitresult, Xnew)
For the [`SimpleInductiveClassifier`](@ref) prediction sets are computed as follows,
``
\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, \ i \in \mathcal{D}_{\text{calibration}}
``
where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data.
"""
function MMI.predict(conf_model::SimpleInductiveClassifier, fitresult, Xnew)
    p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...)
    L = p̂.decoder.classes
    ŷ = pdf(p̂, L)

    # Empirical quantile of the calibration scores:
    v = conf_model.scores
    q̂ = Statistics.quantile(v, conf_model.coverage)

    # Include a label in the prediction set iff its nonconformity score does not
    # exceed q̂. Note that `1.0 - val` hard-codes the default heuristic.
    ŷ = map(x -> collect(key => 1.0 - val <= q̂ ? val : missing for (key, val) in zip(L, x)), eachrow(ŷ))
    return ŷ
end
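
# End-to-end sketch (illustrative; assumes `ConformalSet` subtypes an MLJ model
# type so that machines dispatch to the `MMI.fit`/`MMI.predict` methods above,
# and that MLJ and DecisionTree.jl are installed):
#
#   X, y = MLJ.@load_iris
#   Tree = @load DecisionTreeClassifier pkg=DecisionTree
#   conf_model = SimpleInductiveClassifier(Tree(); coverage=0.95)
#   mach = machine(conf_model, X, y)
#   fit!(mach)
#   predict(mach, X)   # one vector of `label => probability` pairs per row;
#                      # `missing` marks labels excluded from the prediction set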