Skip to content

Instantly share code, notes, and snippets.

@darsnack
Created April 21, 2022 16:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save darsnack/090c51a81df82ce6334211272d39b0d2 to your computer and use it in GitHub Desktop.
Save darsnack/090c51a81df82ce6334211272d39b0d2 to your computer and use it in GitHub Desktop.
Port `stratifiedobs` from MLDataPattern to MLUtils
using Random: shuffle!
using Base: @nexprs, @ntuple
function _splitobs(lm::Dict{T,Vector{I}}, at::AbstractFloat) where {T,I<:Integer}
0 < at < 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
n = mapreduce(length, +, values(lm))
k = length(keys(lm))
# preallocate the indices vectors
idx1 = Vector{I}()
idx2 = Vector{I}()
# sizehint will save us a few heavy memory allocations
# we specify "+ k" to deal with trailing observations when
# the number of observations from a class isn't divideable
# by "at" or "1-at"
sizehint!(idx1, ceil(Int, n * at + k))
sizehint!(idx2, ceil(Int, n * (1-at) + k))
# loop through all label indices
for indices in values(lm)
i1, i2 = splitobs(indices; at = at)
append!(idx1, i1)
append!(idx2, i2)
end
idx1, idx2
end
@generated function _splitobs(lm::Dict{T,Vector{I}}, at::NTuple{N,AbstractFloat}) where {T,I<:Integer,N}
quote
n = mapreduce(length, +, values(lm))
k = length(keys(lm))
# preallocate the indices vectors
@nexprs $(N+1) i -> idx_i = Vector{I}()
# sizehint will save us a few heavy memory allocations
# we specify "+ k" to deal with trailing observations when
# the number of observations from a class isn't divideable
# by "at" or "1-at"
@nexprs $(N) i -> sizehint!(idx_i, ceil(Int, n*at[i] + k))
sizehint!($(Symbol(:idx_, Symbol(N+1))), ceil(Int, n*(1-sum(at)) + k))
# loop through all label indices
for indices in values(lm)
tup = splitobs(indices; at = at)
@nexprs $(N+1) i -> append!(idx_i, tup[i])
end
# return a tuple of all indices vectors
@ntuple $(N+1) idx
end
end
function stratifiedobs(data, labels = [x[2] for x in eachobs(data)]; p, shuffle::Bool = true)
# The given data is always shuffled to qualify as performing
# stratified sampling without replacement.
data_shuf = shuffleobs(data)
idx_tup = _splitobs(group_indices(labels), p)
# Setting the parameter "shuffle = false" specifies that the
# classes are ordered in the resulting subsets respectively.
shuffle && foreach(x->isempty(x) || shuffle!(x), idx_tup)
return map(idx -> obsview(data_shuf, idx), idx_tup)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment