Skip to content

Instantly share code, notes, and snippets.

@FloydHsiu
Last active June 8, 2020 03:40
Show Gist options
  • Save FloydHsiu/828eea345e1ca6950e05bb42f0a75b50 to your computer and use it in GitHub Desktop.
Save FloydHsiu/828eea345e1ca6950e05bb42f0a75b50 to your computer and use it in GitHub Desktop.
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
class SpectralNormalization(layers.Wrapper):
    """Wrapper that spectrally normalizes the kernel of the wrapped layer.

    On every training-mode forward pass, one power-iteration step estimates
    the largest singular value (``sigma``) of the wrapped layer's kernel
    (reshaped to 2-D as ``[-1, kernel.shape[-1]]``).  The kernel variable is
    then overwritten in place with ``kernel / sigma`` before the wrapped
    layer is called.

    Attributes:
        layer: tensorflow keras layers (with kernel attribute)
        w: the wrapped layer's `kernel` variable (set in `build`).
        w_shape: static shape of `w` as a Python list (set in `build`).
        u: non-trainable variable of shape `[1, w_shape[-1]]` holding the
            running power-iteration vector (set in `build`).
    """

    def __init__(self, layer, **kwargs):
        """Wrap `layer`; extra keyword arguments go to `layers.Wrapper`."""
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build `Layer`

        Builds the wrapped layer if necessary, checks that it exposes a
        `kernel`, and creates the power-iteration vector `u`.

        Args:
            input_shape: shape of the input, forwarded to the wrapped layer.

        Raises:
            ValueError: if the wrapped layer has no `kernel` attribute.
        """
        if not self.layer.built:
            self.layer.build(input_shape)
            # Mark the wrapped layer as built so that calling it later does
            # not build it again, which would leave `self.w` referring to a
            # stale kernel variable.
            self.layer.built = True
        if not hasattr(self.layer, 'kernel'):
            raise ValueError(
                '`SpectralNormalization` must wrap a layer that'
                ' contains a `kernel` for weights')
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()
        # `u` must share the kernel's dtype, otherwise the matmuls in
        # `_compute_weights` fail for non-float32 kernels.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            name='sn_u',
            trainable=False,
            dtype=self.w.dtype)
        super(SpectralNormalization, self).build(input_shape)

    @def_function.function
    def call(self, inputs, training=None):
        """Call `Layer`

        Args:
            inputs: input passed unchanged to the wrapped layer.
            training: whether the call is in training mode; when `None`,
                the Keras learning phase is used.

        Returns:
            The output of the wrapped layer.
        """
        if training is None:
            training = K.learning_phase()
        # The explicit comparison is kept from the original implementation.
        if training == True:  # noqa: E712
            # Recompute weights for each forward pass
            self._compute_weights()
        output = self.layer(inputs)
        return output

    @staticmethod
    def _l2_normalize(x, eps):
        """Return `x` divided by its L2 norm, with the norm clamped to `eps`."""
        norm = math_ops.reduce_sum(x ** 2) ** 0.5
        return x / math_ops.maximum(norm, eps)

    def _compute_weights(self):
        """Generate normalized weights.

        This method will update the value of self.layer.kernel with the
        normalized value, so that the layer is ready for call().
        It also stores the updated power-iteration vector in `self.u`.
        """
        w_reshaped = array_ops.reshape(self.w, [-1, self.w_shape[-1]])
        eps = 1e-12

        # One step of power iteration: v <- norm(u W^T), u <- norm(v W).
        _u = array_ops.identity(self.u)
        _v = self._l2_normalize(
            math_ops.matmul(_u, array_ops.transpose(w_reshaped)), eps)
        _u = self._l2_normalize(math_ops.matmul(_v, w_reshaped), eps)
        self.u.assign(_u)

        # sigma = v W u^T has shape [1, 1] and broadcasts over the kernel.
        sigma = math_ops.matmul(
            math_ops.matmul(_v, w_reshaped), array_ops.transpose(_u))
        # Clamp like the norms above so that an all-zero kernel does not
        # turn into NaNs through a division by zero.
        sigma = math_ops.maximum(sigma, eps)
        self.layer.kernel.assign(self.w / sigma)

    def compute_output_shape(self, input_shape):
        """Return the output shape reported by the wrapped layer."""
        return tensor_shape.TensorShape(
            self.layer.compute_output_shape(input_shape).as_list())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment