Last active January 8, 2017 13:27
Gist: vadimkantorov/68e99efe00b990a53ab7e192f0807278
Torch routine and module for whitening, supporting large matrices, e.g. 5K x 10K
require 'torch'

-- Returns the mean and the whitening transform of matrix x[n, d], which holds n data points of dimension d, using a rank-k approximation.
-- The auxiliary Gaussian matrix g[d, f * k] uses f = 3 by default; its column count is capped at d.
-- Original paper: http://users.cms.caltech.edu/~jtropp/papers/HMT11-Finding-Structure-SIREV.pdf
function whiten(x, k, f)
    local n, d = unpack(x:size():totable())
    local mean = x:mean(1)
    local g = x.new():randn(d, math.min((f or 3) * k, x:size(2)))
    -- torch.svd returns U, S, V of its argument. The argument is the transpose of xc * g (xc = centered x),
    -- so v holds the right singular vectors of xc * g and u the left ones.
    local v, s, u = torch.svd(g:t() * (x - mean:expandAs(x)):t())
    -- Invert the top k singular values; the 1e-5 offset guards against division by zero.
    local pinv_s = torch.diag(s:sub(1, k):add(1e-5):pow(-1))
    local w = g * v:narrow(2, 1, k) * pinv_s
    return mean, w
end
-- Example:
-- n = 5000
-- d = 10000
-- k = 128
-- x = torch.FloatTensor():randn(n, d)
-- mean, w = whiten(x, k) -- should take about 30 seconds
-- Y = (x - mean:expandAs(x)) * w -- whiten
-- print((Y:t() * Y):diag()) -- prints a vector of ones (up to numerical error)
require 'nn'

-- Builds a linear whitening layer that subtracts the mean and applies the computed transformation.
-- x is the training data matrix from which the mean and the transformation are estimated.
function WhiteningTransform(x, embedding_size)
    local mean, w = whiten(x, embedding_size)
    local transform = nn.Linear(x:size(2), embedding_size)
    transform.weight:copy(w:t())
    transform.bias:copy(-mean * w)
    return transform
end
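
-- A minimal usage sketch for WhiteningTransform, kept commented out like the example above.
-- The sizes and the names layer and y are illustrative assumptions, not part of the original gist.
-- nn.Linear is created with the default tensor type, so the layer is cast with :float()
-- to match the FloatTensor input before the forward pass.
-- x = torch.FloatTensor():randn(1000, 512)
-- layer = WhiteningTransform(x, 64):float()
-- y = layer:forward(x) -- equivalent to (x - mean:expandAs(x)) * w, size 1000 x 64
-- print((y:t() * y):diag()) -- should be close to a vector of ones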