Last active
November 26, 2020 21:05
-
-
Save sdewaele/296d9547ef937af413b32b370044adc7 to your computer and use it in GitHub Desktop.
Generic StatsBase.ZScoreTransform
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using CUDA | |
using StatsBase | |
import StatsBase: fit, transform!, reconstruct!, mean_and_std | |
""" | |
Standardization (Z-score transformation) | |
""" | |
struct ZScoreTransformGeneric{T<:Real,U<:AbstractVector{T}} <: AbstractDataTransform
    # Number of features the transform was fitted on.
    len::Int
    # Dimension along which the statistics were computed.
    dims::Int
    # Per-feature centering values; an empty vector means "no centering".
    mean::U
    # Per-feature scaling values; an empty vector means "no scaling".
    scale::U

    # Inner constructor: `m` and `s` must each be either empty or of length `l`,
    # otherwise a `DimensionMismatch` is thrown. `dims` is stored unchecked.
    function ZScoreTransformGeneric(l::Int, dims::Int, m::U, s::U) where {T<:Real,U<:AbstractVector{T}}
        for stats in (m, s)
            n = length(stats)
            (n == l || n == 0) || throw(DimensionMismatch("Inconsistent dimensions."))
        end
        return new{T,U}(l, dims, m, s)
    end
end
"""
    getproperty(t::ZScoreTransformGeneric, p::Symbol)

Expose `indim` and `outdim` as aliases of the `len` field; every other name is
looked up as a regular field of `t`.
"""
function Base.getproperty(t::ZScoreTransformGeneric, p::Symbol)
    isalias = p === :indim || p === :outdim
    return isalias ? getfield(t, :len) : getfield(t, p)
end
"""
    fit(ZScoreTransformGeneric, X::AbstractMatrix; dims, center=true, scale=true)

Fit a Z-score standardization to the matrix `X`.

# Keywords
- `dims`: `1` to compute statistics down the rows (one value per column) or `2`
  to compute them across the columns (one value per row). Omitting it is
  deprecated and behaves like `dims=2`.
- `center`: when `false`, the stored mean is an empty vector (no centering).
- `scale`: when `false`, the stored scale is an empty vector (no scaling).

# Throws
- `DomainError` if `dims` is neither `1` nor `2`.
- `ErrorException` if `X` has fewer than two observations along `dims`.
"""
function fit(::Type{ZScoreTransformGeneric}, X::AbstractMatrix{<:Real};
             dims::Union{Integer,Nothing}=nothing, center::Bool=true, scale::Bool=true)
    # Normalize `dims` to a plain `Int` up front so the deprecated
    # `dims === nothing` path defines everything the rest of the function needs
    # and the constructor (which requires `Int`) always receives a valid value.
    if dims === nothing
        Base.depwarn("fit(t, x) is deprecated: use fit(t, x, dims=2) instead", :fit)
        d = 2
    elseif dims == 1 || dims == 2
        d = Int(dims)
    else
        throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
    end

    n = size(X, d)       # number of observations
    l = size(X, 3 - d)   # number of features
    if n < 2
        error(d == 1 ? "X must contain at least two rows." :
                       "X must contain at least two columns.")
    end

    m, s = mean_and_std(X, d)
    mv = vec(m)
    sv = vec(s)
    # Build the "disabled" placeholders from the computed statistics so that both
    # vectors always share the same container and element type, as required by
    # the constructor (e.g. integer input, or arrays that do not live on the CPU).
    return ZScoreTransformGeneric(l, d, (center ? mv : similar(mv, 0)),
                                  (scale ? sv : similar(sv, 0)))
end
"""
    fit(ZScoreTransformGeneric, X::AbstractVector; dims, center=true, scale=true)

Fit a Z-score standardization to the vector `X`, treated as a single feature.

# Keywords
- `dims`: must be `1`. Omitting it is deprecated and behaves like `dims=1`.
- `center`: when `false`, the stored mean is an empty vector (no centering).
- `scale`: when `false`, the stored scale is an empty vector (no scaling).

# Throws
- `DomainError` if `dims` is given and is not `1`.
"""
function fit(::Type{ZScoreTransformGeneric}, X::AbstractVector{<:Real};
             dims::Union{Integer,Nothing}=nothing, center::Bool=true, scale::Bool=true)
    if dims === nothing
        # The only valid value for a vector is 1, so that is what the hint recommends.
        Base.depwarn("fit(t, x) is deprecated: use fit(t, x, dims=1) instead", :fit)
    elseif dims != 1
        throw(DomainError(dims, "fit only accepts dims=1 over a vector. Try fit(t, x, dims=1)."))
    end
    m, s = mean_and_std(X)
    # `dims` is always 1 at this point; pass the literal so the constructor gets an
    # `Int` even when the keyword was omitted or given as another integer type.
    # Placeholders use the statistics' own types so both vectors have the same type.
    return ZScoreTransformGeneric(1, 1, (center ? [m] : typeof(m)[]),
                                  (scale ? [s] : typeof(s)[]))
end
"""
    transform!(y, t::ZScoreTransformGeneric, x)

Standardize `x` with the fitted transform `t`, writing the result into `y`, and
return `y`.

With `t.dims == 1` each column is a feature: the stored mean is subtracted and the
result divided by the stored scale, skipping whichever of the two is empty. With
`t.dims == 2` the same operation is applied to the adjoints of `y` and `x`. For
any other `t.dims`, `y` is returned unmodified.

# Throws
- `DimensionMismatch` if the sizes of `x`, `y` and `t.len` do not agree.
"""
function transform!(y::AbstractMatrix{<:Real}, t::ZScoreTransformGeneric, x::AbstractMatrix{<:Real})
    if t.dims == 2
        # Reduce to the dims=1 case by working on the adjoint wrappers.
        rowwise = ZScoreTransformGeneric(t.len, 1, t.mean, t.scale)
        transform!(y', rowwise, x')
        return y
    end
    t.dims == 1 || return y

    nfeat = t.len
    size(x, 2) == size(y, 2) == nfeat || throw(DimensionMismatch("Inconsistent dimensions."))
    size(x, 1) == size(y, 1) || throw(DimensionMismatch("Inconsistent dimensions."))

    mu = t.mean
    sigma = t.scale
    docenter = !isempty(mu)
    doscale = !isempty(sigma)
    if docenter && doscale
        y .= (x .- mu') ./ sigma'
    elseif docenter
        y .= x .- mu'
    elseif doscale
        y .= x ./ sigma'
    elseif x !== y
        # Nothing to apply: just copy unless the arrays are the same object.
        copyto!(y, x)
    end
    return y
end
"""
    reconstruct!(x, t::ZScoreTransformGeneric, y)

Undo the standardization described by `t`: map `y` back to the original scale,
writing the result into `x`, and return `x`.

With `t.dims == 1` each column is a feature: values are multiplied by the stored
scale and the stored mean is added, skipping whichever of the two is empty. With
`t.dims == 2` the same operation is applied to the adjoints of `x` and `y`. For
any other `t.dims`, `x` is returned unmodified.

# Throws
- `DimensionMismatch` if the sizes of `x`, `y` and `t.len` do not agree.
"""
function reconstruct!(x::AbstractMatrix{<:Real}, t::ZScoreTransformGeneric, y::AbstractMatrix{<:Real})
    if t.dims == 2
        # Reduce to the dims=1 case by working on the adjoint wrappers.
        rowwise = ZScoreTransformGeneric(t.len, 1, t.mean, t.scale)
        reconstruct!(x', rowwise, y')
        return x
    end
    t.dims == 1 || return x

    nfeat = t.len
    size(x, 2) == size(y, 2) == nfeat || throw(DimensionMismatch("Inconsistent dimensions."))
    size(x, 1) == size(y, 1) || throw(DimensionMismatch("Inconsistent dimensions."))

    mu = t.mean
    sigma = t.scale
    docenter = !isempty(mu)
    doscale = !isempty(sigma)
    if docenter && doscale
        x .= y .* sigma' .+ mu'
    elseif docenter
        x .= y .+ mu'
    elseif doscale
        x .= y .* sigma'
    elseif y !== x
        # Nothing to undo: just copy unless the arrays are the same object.
        copyto!(x, y)
    end
    return x
end
"""
    mean_and_std(x::CuArray, dim::Int; corrected=true)

Return the tuple `(mean, std)` of `x` computed along dimension `dim`. The mean is
computed first and handed to `std` through its `mean` keyword; `corrected` is
forwarded to `std` unchanged.
"""
function mean_and_std(x::CuArray, dim::Int; corrected::Bool=true)
    avg = mean(x; dims=dim)
    spread = std(x; corrected=corrected, mean=avg, dims=dim)
    return avg, spread
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment