Created
June 14, 2016 15:26
-
-
Save denis-bz/4e51f03153a7938f762f4c0f356ddd4a to your computer and use it in GitHub Desktop.
Truncated SVD of A = D (Signal, diagonal) + Noise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
""" Gavish + Donoho, "The Optimal Hard Threshold for Singular Values is 4 / sqrt 3", 2014, 14 pages
A = D (Signal, diagonal) + Noise
Atrunc = truncated SVD( A ), the best rank-ntrunc approximation of A
Writing Atrunc = D + Res, how well does Atrunc approximate D ?
|Atrunc| and |Res| increase with ntrunc
"""
# What am I missing:
# if one knows that the Signal is diagonal, why not just threshold A ?
# See also
# https://en.wikipedia.org/wiki/Random_matrix ...
# The large-scale commercialization of random matrix methods began in 2009 :~
# $sklearn/decomposition/truncated_svd.py aka latent semantic analysis (LSA)
from __future__ import division
import os
import sys

import numpy as np
__version__ = "2016-06-14 june denis-bz-py t-online de"

norm = np.linalg.norm  # short alias, used by norms()

# compact array printing: summarize arrays with more than 20 elements,
# showing 10 items at each edge; 140-char lines; floats as "%.2g"
np.set_printoptions( threshold=20, edgeitems=10, linewidth=140,
    formatter = dict( float = lambda x: "%.2g" % x ))

# this script's file name, shown in the printed / plotted params
# (basename also copes with OS-specific path separators, unlike split("/"))
thispy = os.path.basename( __file__ )
def ints( x ):
    """ Round array x to the nearest whole numbers and return an int array.

    Ties go to the even neighbour, as in numpy's round. Used for compact printing.
    """
    nearest = x.round()
    return nearest.astype( int )
def norms( A ):
    """ Summarize matrix A as "F norm <Frobenius> trace norm <nuclear>".

    Both are functions of the singular values s of A:
        Frobenius = sqrt( sum s^2 )  <=  trace (nuclear) = sum s
    e.g. s = (40, 9): F norm 41, trace norm 49.
    Matrix norms are not very intuitive;
    cf. "sums of squares are not very intuitive.md".
    """
    frobenius = norm( A, "fro" )
    nuclear = norm( A, "nuc" )
    return "F norm %4.1f trace norm %3.0f" % ( frobenius, nuclear )
#............................................................................... | |
n = 100  # A, D, Noise are n x n
ntrunc = [2, 10, 23]  # numbers of singular values kept, one truncated SVD each
d = 4  # v sensitive, max 10^4 |normal| ~ 4
rank = 2  # diag: rank//2 entries d, then rank - rank//2 entries -d, rest 0
distrib = "normal"  # name of an np.random function that accepts size=
# distrib = "laplace"
nbin = 10  # number of groups in the singular-value "histogram"
plot = 0  # 1: show Atrunc for each ntrunc, >= 2: also save tmp.png
seed = 0

# to change these params, run this.py n= ... in sh or ipython
# NOTE: exec() runs each argument as arbitrary Python code --
# fine for a personal experiment script, never pass it untrusted input.
for arg in sys.argv[1:]:
    exec( arg )

np.random.seed( seed )
# single-argument print(...) prints the same under Python 2 and 3
print( "\n" + 80 * "-" )
params = "%s n %d d %.1f rank %d ntrunc %s distrib %s seed %d " % (
    thispy, n, d, rank, ntrunc, distrib, seed )
print( params )

# diag: rank//2 d, rank - rank//2 -d
diag = np.zeros( n )
diag[:rank//2] = d
diag[rank//2:rank] = - d
# diag[:2] = 1.7, 2.5  # ?
#............................................................................... | |
D = np.diag( diag )  # the Signal, n x n diagonal
# Noise: n x n iid samples of np.random.<distrib>, scaled by 1 / sqrt n;
# for unit-variance entries (e.g. "normal") its largest singular value is then about 2
Noise = getattr( np.random, distrib )( size=(n,n) ) / np.sqrt(n)
A = D + Noise

# single-argument print(...) prints the same under Python 2 and 3
print( "diag: %s" % (diag,) )
big = np.sort( np.fabs( Noise.reshape(-1) ))[::-1]  # all |Noise| entries, largest first
print( "biggest |Noise|: %s" % (big[:10],) )
print( "D:             %s" % norms( D ))
print( "Noise:         %s" % norms( Noise ))
print( "A = D + Noise: %s" % norms( A ))
#............................................................................... | |
# sing: the n singular values of A, in descending order
U, sing, Vt = np.linalg.svd( A, full_matrices=False )
print( "singular values: %s ... %s" % ( sing[:10], sing[-2:] ))

# "histogram": split the singular values, largest first, into nbin equal-size
# groups and print each group's share of their total, in %.
# The last n % nbin (smallest) values are dropped.
# Skipped when nbin <= 0 or nbin > n: the groups would be empty, giving 0 / 0.
binsize = n // nbin  if nbin > 0  else 0
if binsize > 0:
    bins = sing[ : binsize * nbin ].reshape( nbin, binsize ).sum( axis=1 )
    bins *= 100 / bins.sum()
    print( "histogram %: " + str( ints( bins )))
#............................................................................... | |
def truncsvd( ntrunc ):  # U sing Vt D
    """ Truncated SVD of A, compared with the Signal D.

    Uses the module-level SVD of A (U, sing, Vt) and D:
        Atrunc = U[:, :ntrunc] diag( sing[:ntrunc] ) Vt[:ntrunc]
    which is the best rank-ntrunc approximation of A in F norm (Eckart-Young).
    Prints Atrunc's top-left 10 x 10 corner * 10, rounded to ints,
    and the norms of Atrunc, D and Res = Atrunc - D.
    Returns Atrunc, n x n.
    """
    # single-argument print(...) prints the same under Python 2 and 3
    print( "\nntrunc %d --" % ntrunc )
    # U[:, :k] * sing[:k] scales U's columns == U diag(sing), without building diag(sing)
    Atrunc = np.dot( U[:,:ntrunc] * sing[:ntrunc], Vt[:ntrunc] )
    # bestdiag = Atrunc.diagonal()  # best diagonal approx in F norm
    print( "Atrunc corner * 10:\n%s" % (ints( Atrunc[:10,:10] * 10 ),) )
    print( "\nHow well does Atrunc = D + Res approximate D ?" )
    Res = Atrunc - D
    print( "Atrunc: %s" % norms( Atrunc ))
    print( "D:      %s" % norms( D ))
    print( "Res:    %s" % norms( Res ))
    return Atrunc
# one truncated SVD per ntrunc. A list, not map(): in Python 3 map is lazy,
# so truncsvd would never run (nothing printed) and len( Atruncs ) would fail
Atruncs = [truncsvd( nt ) for nt in ntrunc]
if plot:
    # show Atrunc noise vs ntrunc: one image per ntrunc, side by side
    from matplotlib import pyplot as pl
    import seaborn as sns  # name unused; older seaborn versions restyle matplotlib on import
    # squeeze=False: axes is always 2d, 1 x ncols, even when there is only one ntrunc
    fig, axes = pl.subplots( ncols=len(Atruncs), figsize=[16,6], squeeze=False )
    fig.suptitle( "Truncated SVD of diagonal + random matrices \n%s " % params,
        multialignment="left" )
    for Atrunc, nt, ax in zip( Atruncs, ntrunc, axes[0] ):
        ax.set( xlabel="ntrunc %d" % nt, xticks=[], yticks=[] )
        # sqrt |Atrunc| compresses the range between d and the noise -- presumably the intent
        ax.imshow( np.sqrt( np.fabs( Atrunc )),  # ?
            cmap=pl.cm.Pastel1,
            origin="upper", interpolation=None )
    if plot >= 2:
        from bz.etc import numpyutil as nu  # the author's own helper package
        nu.savefig( "tmp.png", __file__ )
    pl.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment