Skip to content

Instantly share code, notes, and snippets.

@JaeDukSeo
Created May 5, 2019 14:00
Show Gist options
  • Star (1) — you must be signed in to star a gist.
  • Fork (2) — you must be signed in to fork a gist.
  • Save JaeDukSeo/1f07c8d2c0d8419c86e0a31021248270 to your computer and use it in GitHub Desktop.
# Build two filter-normalized direction lists, `random_direction1` and
# `random_direction2`, with one entry per tensor in `copy_of_the_weights`
# (defined outside this snippet).
#
# NOTE(review): despite the names, nothing here draws random numbers. Each
# direction is a deterministic whitening of the corresponding weight tensor:
#   * direction 1: centre every trailing (axis 2, axis 3) slice over those two
#     axes, then whiten it with its own (axis-3 x axis-3) covariance.
#   * direction 2: centre over axes (0, 1), then whiten every (axis 0, axis 1)
#     slice with its own (axis-1 x axis-1) covariance.
# Both are then rescaled so each slice along axis 0 has the same L2 norm as
# the matching slice of `w`.
# NOTE(review): this assumes non-1-D weights are 4-D (presumably conv kernels
# laid out (out, in, kh, kw) — TODO confirm). mean(axis=(2, 3)) fails for 2-D.

_WHITEN_EPS = 1e-5  # added to sqrt(singular value) before taking the reciprocal
_NORM_EPS = 1e-10  # keeps the per-filter norm division finite


def _whiten_last_two_axes(x, eps=_WHITEN_EPS):
    """Whiten every trailing (rows, cols) matrix of *x* along its columns.

    For each leading index, with ``M`` the trailing matrix:
    ``cov = M^T M / cols``, ``U, s = svd(cov)``,
    ``Z = U diag(1 / (sqrt(s) + eps)) U^T``, and the result is ``M @ Z^T``.

    Args:
        x: NumPy array with at least 2 dimensions.
        eps: Stabiliser added to ``sqrt(s)`` before inverting.

    Returns:
        Array with the same shape and dtype as *x*.
    """
    cov = np.swapaxes(x, -1, -2) @ x / x.shape[-1]
    u, s, _ = np.linalg.svd(cov)
    # Scale the columns of U by 1/(sqrt(s)+eps); this equals U @ diag(...).
    # Do NOT write 1/(sqrt(diag(s))+eps): that also inverts the off-diagonal
    # zeros and fills the matrix with 1/eps.
    zca = (u * (1.0 / (np.sqrt(s) + eps))[..., None, :]) @ np.swapaxes(u, -1, -2)
    return x @ np.swapaxes(zca, -1, -2)


random_direction1 = []
random_direction2 = []
for w in copy_of_the_weights:
    if w.dim() == 1:
        # 1-D parameters get all-zero directions, so they never move.
        random_direction1.append(torch.zeros_like(w))
        random_direction2.append(torch.zeros_like(w))
        continue

    # detach(): .numpy() refuses tensors that require grad.
    w_data = w.detach()
    weights_np = w_data.cpu().numpy()

    centred1 = weights_np - weights_np.mean(axis=(2, 3), keepdims=True)
    direction1 = _whiten_last_two_axes(centred1)

    centred2 = weights_np - weights_np.mean(axis=(0, 1), keepdims=True)
    # Move axes (0, 1) to the end, whiten, then move them back.
    # The permutation (2, 3, 0, 1) is its own inverse for a 4-D array.
    direction2 = np.transpose(
        _whiten_last_two_axes(np.transpose(centred2, (2, 3, 0, 1))),
        (2, 3, 0, 1),
    )

    # ascontiguousarray: transposed NumPy views would otherwise yield
    # non-contiguous tensors. Place results on w's own device instead of
    # whatever the current CUDA device is.
    random_vector1 = torch.from_numpy(np.ascontiguousarray(direction1)).to(w.device)
    random_vector2 = torch.from_numpy(np.ascontiguousarray(direction2)).to(w.device)

    # Per-filter (axis 0) L2 norms, shaped (n, 1, 1, 1) to broadcast over w.
    w_norm = w_data.reshape(w_data.shape[0], -1).norm(dim=1, keepdim=True)[:, :, None, None]
    d_norm1 = random_vector1.reshape(random_vector1.shape[0], -1).norm(dim=1, keepdim=True)[:, :, None, None]
    d_norm2 = random_vector2.reshape(random_vector2.shape[0], -1).norm(dim=1, keepdim=True)[:, :, None, None]
    random_vector1 = random_vector1 * (w_norm / (d_norm1 + _NORM_EPS))
    random_vector2 = random_vector2 * (w_norm / (d_norm2 + _NORM_EPS))

    # Debug trace of the resulting shapes.
    print(random_vector1.shape)
    print(random_vector2.shape)

    random_direction1.append(random_vector1)
    random_direction2.append(random_vector2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment