Skip to content

Instantly share code, notes, and snippets.

@shyoshyo
Created May 30, 2018 18:54
Show Gist options
  • Save shyoshyo/99bb2c5b333d18d4feea23d16a1db3ff to your computer and use it in GitHub Desktop.
Save shyoshyo/99bb2c5b333d18d4feea23d16a1db3ff to your computer and use it in GitHub Desktop.
def spectral_normed_weight(W, u=None, num_iters=1, update_collection=tf.GraphKeys.UPDATE_OPS, name='spectral_norm', eps=1e-12, reuse=False):
    """Return ``W`` divided by a power-iteration estimate of its spectral norm.

    ``W`` is flattened to a 2-D matrix of shape ``[n_in, n_out]`` (all leading
    dimensions collapsed into ``n_in``, the last dimension kept as ``n_out``).
    Its largest singular value ``sigma`` is estimated with ``num_iters`` rounds
    of power iteration starting from ``u``. The result is
    ``W / (sigma + eps)`` reshaped back to ``W``'s original shape.

    Args:
        W: Weight tensor. Assumes its static shape is fully defined, since
            ``num_elements()`` / ``as_list()`` are used to build the reshapes.
        u: Optional ``[1, n_out]`` estimate of the leading singular vector.
            If ``None``, a non-trainable variable ``"u"`` with ``W``'s dtype
            is created (or reused, per ``reuse``) inside the ``name`` scope.
            Must support ``.assign`` when ``update_collection`` is not None.
        num_iters: Number of power-iteration rounds per call. May be a
            Python int or an int32 scalar tensor. A Python int below 1 is
            rejected, because zero rounds leave ``v`` at zeros and make
            ``sigma`` 0.
        update_collection: Graph collection that receives the op assigning
            the refined estimate back to ``u``. This function only adds the
            op to the collection; the caller must run it. If ``None``, no
            update op is created and ``u`` is never refined across calls.
        name: Variable-scope name.
        eps: Added to ``sigma`` to avoid division by zero.
        reuse: Passed to ``tf.variable_scope``.

    Returns:
        A tensor with the same shape as ``W``.

    Raises:
        ValueError: If ``num_iters`` is a Python int smaller than 1.
    """
    if isinstance(num_iters, int) and num_iters < 1:
        raise ValueError("num_iters must be >= 1, got %d" % num_iters)

    with tf.variable_scope(name, reuse=reuse):
        W_shape = W.shape.as_list()
        n_in = W.shape[:-1].num_elements()
        n_out = W.shape[-1].value
        # Follow W's dtype rather than hard-coding float32, so that float16 /
        # float64 weights do not hit dtype mismatches in the matmuls or the
        # while_loop's loop variables. base_dtype strips any *_ref dtype.
        dtype = W.dtype.base_dtype
        W_reshaped = tf.reshape(W, [n_in, n_out])
        if u is None:
            u = tf.get_variable(
                "u",
                shape=[1, n_out],
                dtype=dtype,
                initializer=tf.truncated_normal_initializer(),
                trainable=False,
            )

        # Usually num_iters = 1 is enough, assuming the u update op is run
        # every step so the estimate keeps improving across calls.
        def power_iteration(i, u_i, v_i):
            # v = normalize(W u^T)^T: [n_in, n_out] x [n_out, 1] -> [1, n_in]
            v_ip1 = _l2normalize(tf.reshape(
                tf.matmul(W_reshaped, tf.reshape(u_i, shape=[n_out, 1])),
                shape=[1, n_in]))
            # u = normalize(v W): [1, n_in] x [n_in, n_out] -> [1, n_out]
            u_ip1 = _l2normalize(tf.matmul(v_ip1, W_reshaped))
            return i + 1, u_ip1, v_ip1

        _, u_final, v_final = tf.while_loop(
            cond=lambda i, _1, _2: i < num_iters,
            body=power_iteration,
            loop_vars=(
                tf.constant(0, dtype=tf.int32),
                u,
                tf.zeros(dtype=dtype, shape=[1, n_in]),
            ),
        )

        # Treat the singular-vector estimates as constants (SN-GAN
        # formulation): in sigma = v W u^T, gradients flow only into W and
        # not back through the power-iteration steps.
        u_final = tf.stop_gradient(u_final)
        v_final = tf.stop_gradient(v_final)

        sigma = tf.reduce_sum(tf.matmul(v_final, W_reshaped) * u_final)
        W_bar = tf.reshape(W_reshaped / (sigma + eps), W_shape)

        if update_collection is not None:
            tf.add_to_collection(update_collection, u.assign(u_final))
        return W_bar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment