Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created November 17, 2019 16:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/a6382f92ee0e7ff2af1f523dab73384c to your computer and use it in GitHub Desktop.
Save koshian2/a6382f92ee0e7ff2af1f523dab73384c to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow.keras.layers as layers
# https://github.com/IShengFang/SpectralNormalizationKeras/blob/master/SpectralNormalizationKeras.py
class ConvSN2D(layers.Conv2D):
    """2-D convolution whose kernel is spectrally normalized on every call.

    ``build`` creates the same kernel/bias weights as ``layers.Conv2D``. It
    also creates a non-trainable row vector ``u`` that carries the
    power-iteration estimate from one call to the next.

    ``call`` flattens the kernel to a matrix and divides it by an estimate of
    its largest singular value. It then convolves with the result.

    Adapted from IShengFang/SpectralNormalizationKeras.
    """

    def build(self, input_shape):
        """Create ``kernel``, the optional ``bias`` and the vector ``u``.

        Args:
            input_shape: shape of the layer input; its channel axis is chosen
                by ``self.data_format``.

        Raises:
            ValueError: if the channel dimension of ``input_shape`` is None.
        """
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis]
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')

        kernel_shape = self.kernel_size + (input_dim, self.filters)
        self.kernel = self.add_weight(
            name='kernel',
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.bias = (
            self.add_weight(
                name='bias',
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
            if self.use_bias
            else None
        )

        # Shape (1, filters). It persists the power-iteration estimate across
        # calls and is not trained by the optimizer.
        # The MEAN aggregation is presumably for distributed (TPU) training,
        # where replica-local assigns get averaged.
        num_outputs = self.kernel.shape.as_list()[-1]
        self.u = tf.Variable(
            tf.random.normal((1, num_outputs), dtype=tf.float32),
            aggregation=tf.VariableAggregation.MEAN,
            trainable=False,
        )

        self.input_spec = layers.InputSpec(ndim=self.rank + 2,
                                           axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Convolve ``inputs`` with the spectrally normalized kernel.

        Args:
            inputs: input tensor laid out according to ``self.data_format``.
            training: controls whether ``self.u`` is updated.
                - If it compares equal to ``False`` (e.g. ``False`` or ``0``),
                  ``self.u`` is left unchanged.
                - For any other value, including the default ``None``,
                  ``self.u`` is overwritten with the new estimate.

        Returns:
            The convolution output, with bias and activation applied when
            configured.
        """
        def l2_normalize(vec, eps=1e-12):
            # Scale by the L2 norm of the whole tensor; eps guards against 0/0.
            return vec / (K.sum(vec ** 2) ** 0.5 + eps)

        kernel_shape = self.kernel.shape.as_list()
        # Collapse every axis except the last (output filters) into rows.
        w_mat = K.reshape(self.kernel, [-1, kernel_shape[-1]])

        # A single power-iteration step per call. According to the paper,
        # one step is enough because u is carried over between calls.
        # NOTE(review): u_hat/v_hat are not wrapped in tf.stop_gradient, so
        # gradients also flow through this estimate. PyTorch's spectral_norm
        # computes u/v without gradient; confirm this difference is intended.
        v_hat = l2_normalize(K.dot(self.u, K.transpose(w_mat)))
        u_hat = l2_normalize(K.dot(v_hat, w_mat))

        # sigma = v W u^T, a (1, 1) estimate of W's largest singular value.
        sigma = K.dot(K.dot(v_hat, w_mat), K.transpose(u_hat))
        w_normalized = w_mat / sigma

        # Use ==, not truthiness: the default training=None must still
        # refresh u; only an explicit False/0 skips the update.
        u_updates = [] if training == False else [self.u.assign(u_hat)]  # noqa: E712
        with tf.control_dependencies(u_updates):
            kernel_sn = K.reshape(w_normalized, kernel_shape)

        outputs = K.conv2d(inputs, kernel_sn,
                           strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias,
                                 data_format=self.data_format)
        return outputs if self.activation is None else self.activation(outputs)
def upsampling2d_tpu(inputs, scale=2):
    """Repeat every element of ``inputs`` ``scale`` times along axes 1 and 2.

    Axis 1 is expanded first, then axis 2, using only ``repeat_elements``.

    For a channels-last 4-D tensor (presumably (batch, height, width,
    channels); confirm at call sites) this should be equivalent to
    nearest-neighbour ``UpSampling2D(size=scale)``. The ``_tpu`` suffix
    suggests the point is to avoid ops that are awkward on TPUs.

    Args:
        inputs: tensor with at least three axes.
        scale: integer repeat factor applied to both axes.

    Returns:
        Tensor whose axes 1 and 2 are ``scale`` times larger.
    """
    upsampled = inputs
    for spatial_axis in (1, 2):
        upsampled = K.repeat_elements(upsampled, scale, axis=spatial_axis)
    return upsampled
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment