Skip to content

Instantly share code, notes, and snippets.

@severinson
Created September 6, 2020 15:08
Show Gist options
  • Save severinson/914c49809bc58e1fafb5a100948a6fc9 to your computer and use it in GitHub Desktop.
Save severinson/914c49809bc58e1fafb5a100948a6fc9 to your computer and use it in GitHub Desktop.
PCA via gradient sketching
using Random
using LinearAlgebra
using LowRankApprox
"""Orthogonalize and normalize the columns of A in-place."""
function orthogonal!(A)
m, n = size(A)
@inbounds for i in 1:n
@inbounds for j in 1:i-1
view(A, :, i) .-= dot(view(A, :, j), view(A, :, i)) .* view(A, :, j)
end
view(A, :, i) ./= norm(view(A, :, i))
replace!(view(A, :, i), NaN=>zero(eltype(A)))
end
A
end
"""Orthogonalize and normalize the columns of A."""
orthogonal(A) = orthogonal!(copy(A))
"""
row_partition(X, p::Integer)
Split the rows of X into p partitions. Returns an array of views into X, such
that vcat(row_partition(X, p)...) is equal to X.
"""
function row_partition(X, p::Integer)
n = size(X, 1)
c = floor(Int, n/p)
j = n - c*p
append!(
[view(X, (i-1)*c+1:i*c, :) for i in 1:(p-j)], # p-j partitions with c rows
[view(X, c*(p-j)+(i-1)*(c+1)+1:c*(p-j)+i*(c+1), :) for i in 1:j], # j partitions with c+1 rows
)
end
"""Return the angle between a and b"""
angle(a::AbstractVector, b::AbstractVector) = acos(min(max(-1.0, dot(a, b) / norm(a) / norm(b)), 1.0))
"""return the angle between w and v, accounting for their sign"""
function minangle(w, v)
θ = abs(angle(w, v))
return abs(min(θ, θ-π/2))
end
"""
pcagd!(V, X, k::Integer=size(X, 2); nsamples=Threads.nthreads(), α=1.0,
npartitions=10nsamples, niter=10, loginterval=niter)
Return a d by k matrix whose columns correspond to the top k principal
components, computed using stochastic gradient descent.
"""
function pcagd!(V::AbstractMatrix, X::AbstractMatrix;
npartitions::Integer=ceil(Int, reduce(*, size(X))/1e6),
nsamples::Integer=min(Threads.nthreads(), npartitions),
maxiter::Integer=10, atol=1e-3)
size(V, 1) == size(X, 2) || throw(DimensionMismatch("V has dimensions $(size(V)), X has dimensions $(size(X))"))
size(V, 2) <= size(X, 2) || throw(DimensionMismatch("V has dimensions $(size(V)), X has dimensions $(size(X))"))
nsamples > 0 || throw(DomainError(nsamples, "nsamples must be positive"))
npartitions > 0 || throw(DomainError(npartitions, "npartitions must be positive"))
maxiter > 0 || throw(DomainError(maxiter, "maxiter must be positive"))
n, m = size(X)
k = size(V, 2)
# Partition the input matrix
Xs = row_partition(X, npartitions)
# Allocate working memory
W = zeros(ceil(Int, n/npartitions), k) # single-threaded
∇s = [zeros(m, k) for _ in 1:npartitions]
# Angle for stopping criterium
Vend = zeros(m)
# Gradient estimate
∇ = zeros(m, k)
is = collect(1:npartitions)
for t in 1:maxiter
shuffle!(is)
for i in view(is, 1:nsamples)
# Compute the partial gradient with respect to Xs[i] (=-Xs[i]'*Xs[i]*V)
Wv = view(W, 1:size(Xs[i], 1), :)
mul!(Wv, Xs[i], V)
mul!(∇s[i], Xs[i]', Wv)
end
# Compute the gradient
∇ .= V # regularizer gradient
for ∇i in ∇s
∇ .-= ∇i
end
# Gradient step with learning rate 1
V .-= ∇
orthogonal!(V)
# Stopping criterium
θ = minangle(Vend, view(V, :, k))
if θ < atol
break
end
Vend .= view(V, :, 1)
end
V
end
"""
    pcagd(X::AbstractMatrix, k::Integer; kwargs...)

Return a `size(X, 2)`-by-`k` matrix of approximate principal components of `X`.

`pcagd!` is run from a random Gaussian starting point whose columns are orthonormalized
first. All keyword arguments are forwarded to `pcagd!`.
"""
function pcagd(X::AbstractMatrix, k::Integer; kwargs...)
    V0 = randn(size(X, 2), k)
    orthogonal!(V0)
    return pcagd!(V0, X; kwargs...)
end
"""
explained_variance(X, V)
Return the fraction of variance explained by the principal components
in V, defined as tr(V'X'XV) / tr(X'X).
"""
function explained_variance(X, V)
n, d = size(X)
_, k = size(V)
XV = X*V
num = 0.0
@inbounds for i in 1:k
for j in 1:n
num += Float64(XV[j, i])^2
end
end
den = 0.0
@inbounds for i in 1:d
for j in 1:n
den += Float64(X[j, i])^2
end
end
min(num / den, 1.0-eps(Float64))
end
"""
    main(; n=20000, m=2000, k=round(Int, m/4), atol=1e-3)

Compare sketched gradient descent (`pcagd`) against LowRankApprox.jl's `psvd` on an
`n`-by-`m` standard-normal matrix. For each method, print its timing and the variance
explained by its `k` components.
"""
function main(;n=20000, m=2000, k=round(Int, m/4), atol=1e-3)
    X = randn(n, m)

    # Print a label, time `compute()` (which must return the components V), and report
    # the fraction of variance those components explain.
    function report(label, compute)
        println(label)
        V = @time compute()
        println("Explained variance ", explained_variance(X, V))
        println()
    end

    report("Sketched gradient descent", () -> pcagd(X, k, atol=atol))
    # psvd returns (U, S, V); only V is needed here
    report("LowRankApprox.jl", () -> psvd(X, rank=k)[3])
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment