Skip to content

Instantly share code, notes, and snippets.

@sdewaele
Last active November 26, 2020 21:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sdewaele/296d9547ef937af413b32b370044adc7 to your computer and use it in GitHub Desktop.
Save sdewaele/296d9547ef937af413b32b370044adc7 to your computer and use it in GitHub Desktop.
Generic StatsBase.ZScoreTransform
using CUDA
using StatsBase
import StatsBase: fit, transform!, reconstruct!, mean_and_std
"""
Standardization (Z-score transformation)
"""
struct ZScoreTransformGeneric{T<:Real,U<:AbstractVector{T}} <: AbstractDataTransform
    # Number of features the transform was fitted on.
    len::Int
    # Dimension of the data matrix along which observations are laid out (1 or 2).
    dims::Int
    # Per-feature mean; an empty vector means "no centering".
    mean::U
    # Per-feature scale; an empty vector means "no scaling".
    scale::U

    # Inner constructor: `m` and `s` must each be either empty or of length `l`,
    # otherwise a `DimensionMismatch` is thrown.
    function ZScoreTransformGeneric(l::Int, dims::Int, m::U, s::U) where {T<:Real, U<:AbstractVector{T}}
        nmean, nscale = length(m), length(s)
        for n in (nmean, nscale)
            if !(n == l || n == 0)
                throw(DimensionMismatch("Inconsistent dimensions."))
            end
        end
        return new{T,U}(l, dims, m, s)
    end
end
# Property access for `ZScoreTransformGeneric`: the virtual properties `indim`
# and `outdim` both resolve to the stored `len` field; every other name is
# looked up as a regular field.
function Base.getproperty(t::ZScoreTransformGeneric, p::Symbol)
    is_alias = p === :indim || p === :outdim
    return is_alias ? getfield(t, :len) : getfield(t, p)
end
"""
    fit(ZScoreTransformGeneric, X::AbstractMatrix; dims, center=true, scale=true)

Fit a Z-score standardization to the matrix `X` and return a
`ZScoreTransformGeneric`.

# Keywords
- `dims`: `1` if observations are rows (features are columns), `2` if
  observations are columns (features are rows). Omitting it is deprecated and
  behaves like `dims=2`.
- `center`: when `false`, the stored mean is empty and no centering is applied.
- `scale`: when `false`, the stored scale is empty and no scaling is applied.

# Throws
- `DomainError` if `dims` is neither `1` nor `2`.
- `ErrorException` if `X` has fewer than two observations along `dims`.
"""
function fit(::Type{ZScoreTransformGeneric}, X::AbstractMatrix{<:Real};
             dims::Union{Integer,Nothing}=nothing, center::Bool=true, scale::Bool=true)
    if dims === nothing
        Base.depwarn("fit(t, x) is deprecated: use fit(t, x, dims=2) instead", :fit)
        # Resolve the deprecated default up front so that `l` is always defined
        # and the constructor receives an integer rather than `nothing`.
        dims = 2
    end
    if dims == 1
        n, l = size(X)
        n >= 2 || error("X must contain at least two rows.")
    elseif dims == 2
        l, n = size(X)
        n >= 2 || error("X must contain at least two columns.")
    else
        throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
    end
    # `mean_and_std` and the constructor are typed on `Int`; accept any `Integer`.
    d = Int(dims)
    m, s = mean_and_std(X, d)
    # Build the "disabled" placeholders with `similar` so they share the array
    # type and element type of the computed statistics (e.g. Float64 statistics
    # from an integer matrix, or device arrays); the constructor requires the
    # mean and scale vectors to have the same type.
    return ZScoreTransformGeneric(l, d,
                                  (center ? vec(m) : similar(m, 0)),
                                  (scale ? vec(s) : similar(s, 0)))
end
"""
    fit(ZScoreTransformGeneric, X::AbstractVector; dims, center=true, scale=true)

Fit a Z-score standardization to the vector `X`, treated as observations of a
single feature, and return a `ZScoreTransformGeneric` of length 1.

# Keywords
- `dims`: must be `1`. Omitting it is deprecated and behaves like `dims=1`.
- `center`: when `false`, the stored mean is empty and no centering is applied.
- `scale`: when `false`, the stored scale is empty and no scaling is applied.

# Throws
- `DomainError` if `dims` is given and is not `1`.
"""
function fit(::Type{ZScoreTransformGeneric}, X::AbstractVector{<:Real};
             dims::Union{Integer,Nothing}=nothing, center::Bool=true, scale::Bool=true)
    if dims === nothing
        # Only dims=1 is valid for a vector, so that is what the hint recommends.
        Base.depwarn("fit(t, x) is deprecated: use fit(t, x, dims=1) instead", :fit)
        # The constructor needs an integer, never `nothing`.
        dims = 1
    elseif dims != 1
        throw(DomainError(dims, "fit only accepts dims=1 over a vector. Try fit(t, x, dims=1)."))
    end
    # Promote so mean and scale share one element type; the constructor requires
    # both vectors to have the same type, also when one of them is empty.
    m, s = promote(mean_and_std(X)...)
    S = typeof(m)
    return ZScoreTransformGeneric(1, Int(dims),
                                  (center ? S[m] : S[]),
                                  (scale ? S[s] : S[]))
end
"""
    transform!(y, t::ZScoreTransformGeneric, x)

Standardize `x` with the fitted transform `t`, writing the result into `y`, and
return `y`.

For `t.dims == 1` each column `j` becomes `(x[:, j] - mean[j]) / scale[j]`; an
empty `mean` skips centering and an empty `scale` skips scaling. For
`t.dims == 2` the same operation is applied to the adjoints of `x` and `y`.
Throws `DimensionMismatch` when the sizes of `x`, `y` and `t` disagree.
"""
function transform!(y::AbstractMatrix{<:Real}, t::ZScoreTransformGeneric, x::AbstractMatrix{<:Real})
    if t.dims == 2
        # Reuse the column-feature code path on the adjoint views.
        rowwise = ZScoreTransformGeneric(t.len, 1, t.mean, t.scale)
        transform!(y', rowwise, x')
        return y
    end
    # Any value other than 1 or 2 leaves `y` untouched.
    t.dims == 1 || return y

    nfeat = t.len
    size(x, 2) == size(y, 2) == nfeat || throw(DimensionMismatch("Inconsistent dimensions."))
    size(x, 1) == size(y, 1) || throw(DimensionMismatch("Inconsistent dimensions."))

    mu = t.mean
    sigma = t.scale
    centered = !isempty(mu)
    scaled = !isempty(sigma)
    if centered && scaled
        y .= (x .- mu') ./ sigma'
    elseif centered
        y .= x .- mu'
    elseif scaled
        y .= x ./ sigma'
    elseif x !== y
        # Identity transform: only copy when source and destination differ.
        copyto!(y, x)
    end
    return y
end
"""
    reconstruct!(x, t::ZScoreTransformGeneric, y)

Undo the standardization `t` on `y`, writing the result into `x`, and return
`x`.

For `t.dims == 1` each column `j` becomes `y[:, j] * scale[j] + mean[j]`; an
empty `scale` skips rescaling and an empty `mean` skips re-centering. For
`t.dims == 2` the same operation is applied to the adjoints of `x` and `y`.
Throws `DimensionMismatch` when the sizes of `x`, `y` and `t` disagree.
"""
function reconstruct!(x::AbstractMatrix{<:Real}, t::ZScoreTransformGeneric, y::AbstractMatrix{<:Real})
    if t.dims == 2
        # Reuse the column-feature code path on the adjoint views.
        rowwise = ZScoreTransformGeneric(t.len, 1, t.mean, t.scale)
        reconstruct!(x', rowwise, y')
        return x
    end
    # Any value other than 1 or 2 leaves `x` untouched.
    t.dims == 1 || return x

    nfeat = t.len
    size(x, 2) == size(y, 2) == nfeat || throw(DimensionMismatch("Inconsistent dimensions."))
    size(x, 1) == size(y, 1) || throw(DimensionMismatch("Inconsistent dimensions."))

    mu = t.mean
    sigma = t.scale
    centered = !isempty(mu)
    scaled = !isempty(sigma)
    if centered && scaled
        x .= y .* sigma' .+ mu'
    elseif centered
        x .= y .+ mu'
    elseif scaled
        x .= y .* sigma'
    elseif y !== x
        # Identity transform: only copy when source and destination differ.
        copyto!(x, y)
    end
    return x
end
"""
    mean_and_std(x::CuArray, dim::Int; corrected=true)

Return the tuple `(mean, std)` of `x` computed along dimension `dim`. The mean
is computed once and passed to `std` so it is not recomputed; `corrected` is
forwarded to `std`.
"""
function mean_and_std(x::CuArray, dim::Int; corrected::Bool=true)
    avg = mean(x; dims=dim)
    dev = std(x; corrected=corrected, mean=avg, dims=dim)
    return avg, dev
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment