Skip to content

Instantly share code, notes, and snippets.

@jiahao
Created November 22, 2019 10:38
Show Gist options
  • Save jiahao/e6d2af718411138d9fb41ab508b54dfa to your computer and use it in GitHub Desktop.
Sparse logistic PCA in Julia - translated from @andland 's implementation https://github.com/andland/SparseLogisticPCA
using LinearAlgebra
using StatsBase
using StatsFuns
using NaNMath
"Map each entry `x` to `2x-1` in place; `NaN` entries become `val` (default 0.0). Returns `dat`."
function twoxm1!(dat; val=0.0)
    @inbounds for idx in eachindex(dat)
        x = dat[idx]
        dat[idx] = isnan(x) ? val : 2x - 1
    end
    dat
end
"""
Perform soft shrinkage (soft thresholding) in place: entries with magnitude
below `abs(val)` are zeroed, all others are shrunk toward zero by `abs(val)`.
Returns `dat`.
"""
function softshrink!(dat, val=0)
    val = abs(val)  # a negative threshold is treated as its magnitude
    @inbounds for (i, x) in enumerate(dat)
        y = abs(x)
        # zero(x) keeps the element type stable; the literal 0 would mix an
        # Int into the ifelse result for float arrays.
        dat[i] = ifelse(y < val, zero(x), x - copysign(val, x))
    end
    dat
end
"""
Rescale the columns of `B` to unit Euclidean norm, scaling the columns of `A`
by the same factors so that the product `A*B'` is unchanged.
"""
function normalize!(A::Matrix{T}, B::Matrix{T}) where T
    colnorms = sqrt.(sum(x -> x^2, B, dims=1))
    A .*= colnorms
    B ./= colnorms
end
"Return a copy of `dat` with each column's mean subtracted (column centering)."
function center(dat)
    return dat .- mean(dat; dims=1)
end
"""
Build starting values for the factorization: map `dat` to ±1 in place (NaN → 0),
then either warm-start `(A, B, μ)` from a dense SVD of the centered data or
draw them at random. Optional `A0`/`B0`/`μ0` override the computed start.
Returns `(A, B, μ, n, d, q)` where `q` aliases the mutated `dat`.
NOTE(review): mutates `dat` — callers pass a copy when the original is needed.
"""
function _initialize!(A0, B0, μ0, dat, k, randstart)
q = twoxm1!(dat) # forces x to be equal to θ when data is missing
n, d = size(dat)
# Initialize #
##################
if (!randstart)
# Deterministic warm start: rank-k truncated SVD of the centered ±1 data.
μ=mean(q, dims=1)
μ=reshape(μ,d,1)
F=svd(center(q))
A=F.U[1:n,1:k]
# Fold the singular values into B so A has orthonormal columns.
B=F.V[1:d,1:k] * Diagonal(F.S[1:k])
else
# Random start: μ ~ N(0,1), A and B uniform on [-1, 1].
μ=randn(d,1)
A=2rand(n,k).-1
B=2rand(d,k).-1
end
if A0 !== nothing
# Rescale the supplied A0 to unit column norms, compensating in B.
W = sqrt.(sum(x->x^2, A0, dims=1))
if B0 !== nothing
B = B0 .* W
end
A = A0 ./ W
end
if μ0 !== nothing
μ = μ0
end
A, B, μ, n, d, q
end
"""
Overwrite `X` with `4q .* (1 .- logistic.(q .* (ones(n)*μ' .+ A*B')))`,
using `X` itself as scratch for the product `A*B'` so no temporaries are
allocated. Returns `X`.
"""
function computeX!(q, μ, A, B, X)
    mul!(X, A, B')  # X temporarily holds A*B'
    nrow, ncol = size(X)
    @inbounds for j in 1:ncol, i in 1:nrow
        θ = q[i, j] * (μ[j] + X[i, j])
        X[i, j] = 4q[i, j] * (1 - logistic(θ))
    end
    return X
end
"""
Return the Bernoulli log-likelihood `sum(log.(logistic.(q .* (ones(n)*μ' + A*B'))))`,
skipping NaN terms (missing data). Clobbers `X`, which is used as scratch
for the product `A*B'`.
"""
function compute_loglike!(q, μ, A, B, X)
    mul!(X, A, B')
    nrow, ncol = size(X)
    total = 0.0
    @inbounds for j in 1:ncol, i in 1:nrow
        term = log(logistic(q[i, j] * (μ[j] + X[i, j])))
        if !isnan(term)
            total += term
        end
    end
    return total
end
"""
From Lee, Huang, Hu (2010)
Uses the uniform bound for the log likelihood
Can only use lasso=true if λ is the same for all dimensions,
which is how this algorithm is coded

Sparse logistic PCA via majorization-minimization. Returns
`(μ, A, B, nzeros, BIC, iters, loss_trace, λ)` — offsets, scores, sparse
loadings, number of (near-)zero loadings, BIC, iterations run, per-iteration
penalized loss, and the penalty used.
NOTE(review): `dat` is mutated by the initializer (entries mapped to ±1,
NaN → 0) — pass a copy if the original matrix is needed afterwards.
"""
function splogitpca(dat::Matrix; λ=0,k=2,verbose=false,maxiters::Int=100,convcrit=1e-5,
randstart=false,procrustes=true,lasso=true,normalize=false,
A0=nothing, B0=nothing, μ0=nothing, eps = 1e-10)
# Warm start from a dense SVD (or random start); q aliases the mutated dat.
A, B, μ, n, d, q = _initialize!(A0, B0, μ0, dat, k, randstart)
loss_trace=zeros(maxiters)
loglike = 0.0
iters = 0
μ_prev=μ
A_prev=A
B_prev=B
# Scratch buffer reused by computeX!/compute_loglike! each iteration.
X = zeros(size(A,1), size(B,1))
for m = 1:maxiters
μ_prev=μ
A_prev=A
B_prev=B
# --- μ update (the commented code is the original R-style formulation) ---
# θ=ones(n)*μ'+A * B'
# X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
# Xcross=X-A * B'
# μ=(1/n*Xcross'ones(n))
# X = 4q.*(1 .- logistic.(q.*(ones(n)*μ'+A*B')))
X = computeX!(q, μ, A, B, X)
μ += vec(sum(X, dims=1))/n
# --- A update: orthogonal Procrustes solution or QR orthogonalization ---
# θ=ones(n)*μ'+A * B'
# X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
# Xstar=X-ones(n)*μ'
# if (procrustes)
# M=svd(Xstar * B)
# A=M.U * M.Vt
# else
# A = Matrix(qr(Xstar * pinv(B)').Q)
# end
# X = 4q.*(1 .- logistic.(q.*(ones(n)*μ'+A*B')))
X = computeX!(q, μ, A, B, X)
if (procrustes)
# Xstar*B expanded as A*(B'B) + X*B, avoiding the n×d intermediate.
M = svd(A*B'B + X*B)
A[:] = M.U * M.Vt
else
A[:] = Matrix(qr(A + X/B').Q)
end
# --- B update: least squares target, then lasso or ridge-like shrinkage ---
# θ=(ones(n)*μ') + A * B'
# X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
# Xstar= X - ones(n)*μ'
# X = 4q.*(1 .- logistic.(q.*(ones(n)*μ'+A*B')))
X = computeX!(q, μ, A, B, X)
# Xstar = X + A*B'
# C = Xstar'A
C = X'A + B*A'A
if lasso
B = softshrink!(C, 4*n*λ)
# B = sign.(B_lse).*rectify.(abs.(B_lse).-4*n*λ)
else
B[:] = C ./ (1 .+ 4*n*λ*inv.(abs.(B)))
# B=abs.(B)/(abs.(B).+4*n*λ).*C
end
loglike=compute_loglike!(q, μ, A, B, X)
# loglike=NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A * B'))))
penalty=n*λ*sum(abs, B)
# Penalized loss per observed (non-NaN) entry.
loss_trace[m]=(-loglike+penalty)/count(x->!isnan(x), dat)
iters = m
if verbose
println(m," ",(-loglike)," ",(penalty)," ",-loglike+penalty, " ", loss_trace[m], " ", convcrit)
end
#Converged?
if (m>4) && (loss_trace[m-1]-loss_trace[m])<convcrit
break
end
end
# if iters > 1 && loss_trace[iters-1]<loss_trace[iters] #This iteration doesn't count
# μ=μ_prev
# A=A_prev
# B=B_prev
# loglike=NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A * B'))))
# iters -= 1
# end
if (normalize)
normalize!(A, B)
end
# Loadings with magnitude below eps count as exact zeros for the BIC.
nzeros=count(x->abs(x)<eps, B)
# BIC=-2*loglike+log(n)*(d+n*k+count(x->abs(x)>=eps, B))
BIC=-2*loglike+log(n)*(d+n*k+d*k-nzeros)
return μ, A, B, nzeros, BIC, iters, loss_trace[1:iters], λ
end
using IterativeSolvers
"""
Variant of `_initialize!` that warm-starts from an iterative Lanczos SVD
(`IterativeSolvers.svdl`) instead of a dense `svd`, which avoids the full
decomposition for large matrices. Otherwise identical: maps `dat` to ±1 in
place and returns `(A, B, μ, n, d, q)`.
"""
function _initializel!(A0, B0, μ0, dat, k, randstart)
q = twoxm1!(dat) # forces x to be equal to θ when data is missing
n, d = size(dat)
# Initialize #
##################
if (!randstart)
μ = mean(q, dims=1)
μ = reshape(μ,d,1)
# svdl returns (factorization, convergence history); only the first is used.
# NOTE(review): assumes the factorization exposes U, S, V like
# LinearAlgebra.SVD — confirm against the installed IterativeSolvers version.
F, _ = svdl(center(q), nsv=k, vecs=:both)
A=F.U[1:n,1:k]
B=F.V[1:d,1:k] * Diagonal(F.S[1:k])
else
# Random start: μ ~ N(0,1), A and B uniform on [-1, 1].
μ=randn(d,1)
A=2rand(n,k).-1
B=2rand(d,k).-1
end
if A0 !== nothing
# Rescale the supplied A0 to unit column norms, compensating in B.
W = sqrt.(sum(x->x^2, A0, dims=1))
if B0 !== nothing
B = B0 .* W
end
A = A0 ./ W
end
if μ0 !== nothing
μ = μ0
end
A, B, μ, n, d, q
end
"""
My version using Iterative SVD (L for Lanczos!)

Sparse logistic PCA identical to `splogitpca` except the warm start uses the
iterative Lanczos SVD (`_initializel!`) rather than a dense SVD, which is
cheaper for large matrices. Returns
`(μ, A, B, nzeros, BIC, iters, loss_trace, λ)`.
NOTE(review): `dat` is mutated by the initializer — pass a copy if needed.
"""
function splogitpcal(dat::Matrix; λ=0,k=2,verbose=false,maxiters::Int=100,convcrit=1e-5,
        randstart=false,procrustes=true,lasso=true,normalize=false,
        A0=nothing, B0=nothing, μ0=nothing, eps = 1e-10)
    # Fix: use the Lanczos-SVD initializer this function exists for;
    # previously this called the dense-SVD _initialize!, leaving
    # _initializel! dead code and defeating the docstring's purpose.
    A, B, μ, n, d, q = _initializel!(A0, B0, μ0, dat, k, randstart)
    loss_trace = zeros(maxiters)
    loglike = 0.0
    iters = 0
    μ_prev = μ
    A_prev = A
    B_prev = B
    # Scratch buffer reused by computeX!/compute_loglike! each iteration.
    X = zeros(size(A,1), size(B,1))
    for m = 1:maxiters
        μ_prev = μ
        A_prev = A
        B_prev = B
        # μ update (majorize-minimize step on the column offsets).
        X = computeX!(q, μ, A, B, X)
        μ += vec(sum(X, dims=1))/n
        # A update: orthogonal Procrustes solution or QR orthogonalization.
        X = computeX!(q, μ, A, B, X)
        if (procrustes)
            M = svd(A*B'B + X*B)
            A[:] = M.U * M.Vt
        else
            # Fix: this statement was accidentally duplicated, orthogonalizing
            # twice with the already-updated A; the dense-SVD twin does it once.
            A[:] = Matrix(qr(A + X/B').Q)
        end
        # B update: least-squares target, then lasso or ridge-like shrinkage.
        X = computeX!(q, μ, A, B, X)
        C = X'A + B*A'A
        if lasso
            B = softshrink!(C, 4*n*λ)
        else
            B[:] = C ./ (1 .+ 4*n*λ*inv.(abs.(B)))
        end
        loglike = compute_loglike!(q, μ, A, B, X)
        penalty = n*λ*sum(abs, B)
        # Penalized loss per observed (non-NaN) entry.
        loss_trace[m] = (-loglike+penalty)/count(x->!isnan(x), dat)
        iters = m
        if verbose
            println(m," ",(-loglike)," ",(penalty)," ",-loglike+penalty, " ", loss_trace[m], " ", convcrit)
        end
        # Converged?
        if (m>4) && (loss_trace[m-1]-loss_trace[m])<convcrit
            break
        end
    end
    if (normalize)
        normalize!(A, B)
    end
    # Loadings with magnitude below eps count as exact zeros for the BIC.
    nzeros = count(x->abs(x)<eps, B)
    BIC = -2*loglike+log(n)*(d+n*k+d*k-nzeros)
    return μ, A, B, nzeros, BIC, iters, loss_trace[1:iters], λ
end
"""
Coordinate-wise sparse logistic PCA over a grid of penalties `λs`
(Lee & Huang 2013), selecting the best penalty per component by BIC.
NOTE(review): this function is an incomplete translation from R and does not
run as Julia code as written — see the inline notes on the R residue.
"""
function splogitpcacoords(dat; λs=exp10.(range(-2,stop=2,length=10)),k=2,verbose=false,maxiters=100,convcrit=1e-5,
randstart=false,normalize=false,
A0=nothing,B0=nothing,μ0=nothing,eps = 1e-10)
# From Lee, Huang (2013)
# Uses the uniform bound for the log likelihood
# Initialize #
##################
A, B, μ, n, d, q = _initialize!(A0, B0, μ0, dat, k, randstart)
# NOTE(review): R residue — `dimnames=list(paste0(...))` is R syntax; Julia's
# `fill` has no `dimnames` keyword. These three lines will not parse/run.
BICs=fill(NaN,length(λs),k,dimnames=list(paste0("10^",round(log10(λs),2)),1:k))
zeros_mat=fill(NaN,length(λs),k,dimnames=list(paste0("10^",round(log10(λs),2)),1:k))
iters=fill(NaN,length(λs),k,dimnames=list(paste0("10^",round(log10(λs),2)),1:k))
θ=ones(n)*μ'+A * B'
X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
Xcross=X-A * B'
μ=(1/n*(Xcross)' * ones(n))
loglike = 0.0
# NOTE(review): `iters` is rebound here from matrix to scalar, shadowing the
# per-λ iteration matrix above.
iters = 0
for m in 1:k
A_prev=A
B_prev=B
θ=ones(n)*μ'+A * B'
X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
# NOTE(review): `A[:,-m]` is R's drop-column indexing; Julia would need
# something like A[:, setdiff(1:k, m)].
Xm=X-(ones(n)*μ')+A[:,-m] * B[:,-m]'
Bms=fill(NaN,d,length(λs))
Ams=fill(NaN,n,length(λs))
for λ in λs
for i in 1:maxiters
if (sum(x->x^2, B[:,m])==0)
A[:,m]=Xm * B[:,m]
break
end
# Power-iteration-style alternating update of the m-th component.
A[:,m]=Xm * B[:,m]/sum(x->x^2, B[:,m])
A[:,m]=A[:,m]/sqrt(sum(x->x^2, A[:,m]))
B_lse=Xm'*A[:,m]
# NOTE(review): `rectify` is not defined anywhere in this file.
B[:,m]=sign.(B_lse).*rectify.(abs.(B_lse).-λ)
loglike=NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A * B'))))
penalty=0.25*λ*sum(abs, B[:,m])
loss=(-loglike+penalty)/count(x->!isnan(x), dat)
iters = m
if verbose
println(m," ",(-loglike)," ",(penalty)," ",-loglike+penalty)
end
#Converged?
# NOTE(review): `prev_loss` is read before its first assignment on
# iteration i==5's check path ordering — verify intent.
if (i>4) && (prev_loss-loss)/prev_loss<convcrit
break
end
prev_loss=loss
end
# NOTE(review): `λ==λs` is a scalar-vs-vector comparison (always false in
# Julia); R's vectorized `==` index was intended (`λ .== λs`).
Bms[:,λ==λs]=B[:,m]/ifelse(sum(B[:,m]^2)==0,1,sqrt(sum(B[:,m]^2)))
Ams[:,λ==λs]=Xm * Bms[:,λ==λs]/ifelse(sum(Bms[:,λ==λs]^2)==0,1,sum(Bms[:,λ==λs]^2))
BICs[λ==λs,m]=-2*loglike+log(n*d)*(sum(abs.(B).>=eps))
zeros_mat[λ==λs,m]=sum(abs.(B[:,m]).<eps)
iters[λ==λs,m]=i
end
# NOTE(review): `which.min` is R; Julia's equivalent is `argmin`.
B[:,m]=Bms[:,which.min(BICs[:,m])]
A[:,m]=Ams[:,which.min(BICs[:,m])]
end
if (normalize)
normalize!(A, B)
end
nzeros=sum(abs.(B).<eps)
BIC=-2*loglike+log(n*d)*(sum(abs.(B).>=eps))
return μ, A, B, nzeros, zeros_mat, BICs, BIC,λs, iters
end
# Simple tests
let
# Fixed 10×4 rank-2 test problem; A0/B0 are a frozen draw of rand(n,k)/rand(d,k).
n, d, k = 10, 4, 2
# A0 = rand(n, k)
# B0 = rand(d, k)
A0 = [0.004757656 0.5484724
0.097842901 0.7151051
0.847809061 0.9244518
0.587115318 0.1226018
0.645016418 0.9158401
0.512536788 0.7956434
0.630291718 0.7374858
0.744668306 0.9878830
0.824168577 0.9300954
0.500897135 0.2217333]
B0= [
0.65122711 0.5302900
0.12274696 0.2805810
0.09102886 0.1355389
0.91832927 0.2721085]
X = A0*B0'
# Test 1: deterministic SVD warm start matches reference factors.
# NOTE(review): expected values presumably come from a reference run of the
# R implementation — confirm provenance before tightening tolerances.
let
A, B, μ, n, d, q = _initialize!(nothing, nothing, nothing, copy(X), k, false)
A1 = [
-0.5388835 0.333868208
-0.3754175 0.428292562
0.3857167 -0.028172460
-0.2365052 -0.652655997
0.2034485 0.146812728
0.0281748 0.138057987
0.1030301 -0.031392469
0.3262791 0.133665953
0.3677298 -0.000633053
-0.2635728 -0.467843459
]
B1 = [
1.7012514 0.2104329
0.5569920 0.2680516
0.3181553 0.1070674
1.7766535 -0.3047110
]
@assert norm(A - A1) < 1e-6
@assert norm(B - B1) < 1e-6
end
# Test 2: full regression of splogitpca (Procrustes A-update, default λ=0).
# copy(X) because splogitpca mutates its argument.
let
μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(copy(X))
μ1 = [0.11143104907230131, -6.564739360771476, -5.150479248661679, 3.1076465866696625]
A1 = [-0.4711911569853924 -0.1489574609442475; -0.6112679385689244 -0.21907436712904096; 0.17677392090006658 -0.32920587480319885; -0.12646233029137596 -0.5057072366168655; 0.21416589014147108 -0.29440618782253014; 0.29999916840057045 -0.2510075989146518; 0.25869949074847665 -0.23679822898302308; 0.18476194498116696 -0.33697878594619274; 0.17935771304814505 -0.3272536782315915; -0.29663199282213143 -0.3732276622116619]
B1 = [29.334390862738818 3.0935769680639233; -3.1451796495190014 17.718901626804065; -0.8610989625744958 11.142872726019872; 20.653625916707483 -11.121500611519389]
nzeros1 = 0
BIC1 = 78.09400917615692
iters1 = 100
losses1 = [0.3192568800107156, 0.24977073285809892, 0.21330490458541732, 0.1899828633617256, 0.1735122256779609, 0.16114605957633704, 0.15146109632832494, 0.14363527557579114, 0.1371565015500379, 0.13168741770455047, 0.12699595883671305, 0.12291687282860156, 0.11932909880827051, 0.11614181739372556, 0.11328550369088539, 0.11070599638614712, 0.10836045395242884, 0.10621452962566721, 0.10424035540856966, 0.1024150761609641, 0.10071976572936119, 0.09913861346320799, 0.09765830536474736, 0.0962675475009829, 0.09495669485112829, 0.09371745929147622, 0.09254267767201296, 0.09142612601238616, 0.09036236944427692, 0.08934664011448008, 0.08837473714519413, 0.08744294413267219, 0.08654796069463733, 0.08568684534937958, 0.08485696759454661, 0.084055967500555, 0.08328172147765685, 0.08253231314263992, 0.08180600841968862, 0.08110123417396806, 0.08041655980631743, 0.07975068134082348, 0.07910240761983256, 0.07847064828761784, 0.07785440329787183, 0.07725275372407196, 0.07666485368762423, 0.07608992324812383, 0.07552724212433402, 0.07497614413457007, 0.07443601226186682, 0.07390627426323096, 0.07338639875394007, 0.07287589170764217, 0.0723742933212689, 0.07188117520075957, 0.07139613782951738, 0.07091880828656347, 0.07044883818565822, 0.06998590181034249, 0.06952969442301517, 0.06907993072888352, 0.06863634347797201, 0.06819868219040548, 0.0677667119919442, 0.06734021254827935, 0.0669189770879308, 0.06650281150475232, 0.06609153353207073, 0.0656849719813751, 0.06528296603925979, 0.06488536461701414, 0.06449202574786106, 0.06410281602738557, 0.0637176100931686, 0.06333629014006628, 0.0629587454679458, 0.0625848720590257, 0.06221457218226193, 0.06184775402248566, 0.061484331332236145, 0.06112422310443886, 0.060767353264271115, 0.06041365037872024, 0.060063047382495384, 0.059715481319083705, 0.05937089309586481, 0.059029227252303126, 0.05869043174033748, 0.058354457716168026, 0.05802125934272438, 0.057690793602161984, 0.0573630201177989, 0.05703790098495918, 0.05671540061023694, 
0.05639548555874026, 0.05607812440891309, 0.05576328761456515, 0.05545094737377308, 0.055141077504343126]
λ1 = 0
@assert norm(μ - μ1) < 1e-10
@assert norm(A - A1) < 1e-10
@assert norm(B - B1) < 1e-10
@assert nzeros == nzeros1
@assert norm(BIC - BIC1) < 1e-10
@assert iters == iters1
@assert norm(losses - losses1) < 1e-10
@assert norm(λ - λ1) < 1e-10
end
# Test 3: same regression with the QR orthogonalization path (procrustes=false).
let
μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(copy(X), procrustes = false)
μ1 = [0.15487984167564228, -6.568819637068933, -5.153303852683945, 3.1354113472419947]
A1 = [-0.4713085616574335 -0.1486499336486215; -0.6116225488233055 -0.21892149365503077; 0.17668326805953308 -0.3291070254305201; -0.12676563488667017 -0.5058256360958331; 0.2136841499552063 -0.2947788948414918; 0.29919833886531016 -0.25214788479509065; 0.2579814533003914 -0.2376067790175394; 0.18463978727899122 -0.3369513463746447; 0.17923695827317881 -0.3271853654971497; -0.2975682504108814 -0.37187223776149625]
B1 = [29.36371490195748 3.3473468390064522; -3.3541459487802157 17.669448756051334; -0.9910645451446817 11.124796865100537; 20.688053338772107 -10.876220678727316]
nzeros1 = 0
BIC1 = 78.09522219179412
iters1 = 100
losses1 = [0.31927078677106835, 0.24979181872115347, 0.2133183230695471, 0.18999383437616968, 0.1735228915110179, 0.16115692983943103, 0.1514721476812388, 0.14364636426924043, 0.13716749159847258, 0.1316982099026348, 0.12700648917113633, 0.12292710375757314, 0.11933901071427137, 0.11615140198242953, 0.11329475936292606, 0.11071492516618249, 0.10836905953592055, 0.10622281619007441, 0.10424832693190614, 0.10242273606938757, 0.1007271167532976, 0.09914565762194505, 0.09766504402696699, 0.09627398148718055, 0.09496282455615426, 0.09372328481354868, 0.09254819893988216, 0.09143134290580095, 0.09036728190515164, 0.08935124824758726, 0.08837904130810297, 0.08744694501519108, 0.08655165938832812, 0.08569024340786899, 0.08486006708542487, 0.08405877104965467, 0.08328423230647461, 0.08253453509960995, 0.08180794600595748, 0.08110289256425271, 0.08041794486536624, 0.07975179963593723, 0.0791032664298458, 0.07847125560869023, 0.07785476784639267, 0.0772528849369398, 0.07666476172012959, 0.07608961896963139, 0.07552673711193403, 0.07497545066484831, 0.07443514330092219, 0.07390524345505597, 0.07338522040726606, 0.07287458078134124, 0.07237286540839802, 0.07187964651132406, 0.0713945251720289, 0.07091712904846326, 0.07044711031267445, 0.06998414378485004, 0.06952792524146519, 0.06907816987837065, 0.06863461091200698, 0.06819699830396261, 0.0677650975958532, 0.06733868884303228, 0.06691756563697579, 0.06650153420734686, 0.06609041259576673, 0.06568402989420989, 0.06528222554172623, 0.06488484867388328, 0.06449175751993028, 0.06410281884322391, 0.06371790742093102, 0.06333690555944513, 0.06295970264232852, 0.06258619470792325, 0.06221628405407158, 0.06184987886764911, 0.061486892876849705, 0.061127245024371245, 0.060770859159838375, 0.060417663749967876, 0.06006759160512929, 0.059720579621091584, 0.05937656853486391, 0.059035502693647504, 0.05869732983601197, 0.05836200088449596, 0.05802946974890764, 0.05769969313967223, 0.057372630390632874, 0.057048243290767894, 
0.056726495924335395, 0.05640735451900135, 0.05609078730154475, 0.05577676436076882, 0.055465257517278245, 0.05515624019980808]
λ1 = 0
@assert norm(μ - μ1) < 1e-10
@assert norm(A - A1) < 1e-10
@assert norm(B - B1) < 1e-10
@assert nzeros == nzeros1
@assert norm(BIC - BIC1) < 1e-10
@assert iters == iters1
@assert norm(losses - losses1) < 1e-10
@assert norm(λ - λ1) < 1e-10
end
# Test 4: normalize! preserves the product A*B' and yields unit-norm B columns.
let
A1 = copy(A0)
B1 = copy(B0)
normalize!(A1, B1)
@assert norm(A0*B0' - A1*B1') < 1e-10
@assert norm(norm(B1[:,1]) - 1) < 1e-10
@assert norm(norm(B1[:,2]) - 1) < 1e-10
end
end
# using Profile
# Profile.clear()
# X=copy(X0)
# @profile q=splogitpca(X) #1.16 seconds
# Profile.print()
#
#R: 7.65368 secs
# Benchmark on a 2000×1600 random matrix: dense-SVD warm start (splogitpca)
# vs Lanczos-SVD warm start (splogitpcal). A fresh copy of X0 is made before
# each timed call because both functions mutate their argument in place.
n,m=2000,1600
X0=rand(n,m)
X=copy(X0)
@time q=splogitpca(X)
X=copy(X0)
@time q=splogitpcal(X)
# X=copy(X0)
# @time q=splogitpca(X, procrustes=false)
# X=copy(X0)
# @time q=splogitpca(X, procrustes=false, lasso=false)
#2.6s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment