UNet implementation of the MATLAB sample for semantic segmentation: https://jp.mathworks.com/help/images/multispectral-semantic-segmentation-using-deep-learning.html?lang=en . Outputs were produced with different hyperparameters.
# ===== training script =====
import imageio
import numpy as np
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, Callback
from keras import optimizers
import keras.backend as K
import matplotlib.pyplot as plt
from models import Pix2Pix, SegNet, vgg19_unet, UNetMatlab

np.set_printoptions(threshold=64**4, linewidth=300)

def random_crop(image, top, left, crop_size):
    bottom = top + crop_size[0]
    right = left + crop_size[1]
    # Copy, so that in-place augmentation (noise) cannot mutate the source image.
    return image[top:bottom, left:right, :].copy()

def get_datagen(img_path, seg_path, img_size=(256, 256), batch_size=16, train=True, sample_weights=None):
    img = imageio.imread(img_path, pilmode='RGB')
    seg = imageio.imread(seg_path, pilmode='RGB')
    # Keep the annotated region and build a two-channel one-hot mask.
    img = img[400:2000, 270:2200, :]
    seg = seg[400:2000, 270:2200, :2]
    seg[:, :, 1] = 255 - seg[:, :, 0]  # Index 0: Defect, Index 1: Background
    img = img.astype(float) / 255.
    seg = seg.astype(float) / 255.
    imgs = []
    segs = []
    h, w, _ = img.shape
    while True:
        # Random crop
        top = np.random.randint(0, h - img_size[0])
        left = np.random.randint(0, w - img_size[1])
        cropped_img = random_crop(img, top, left, img_size)
        cropped_seg = random_crop(seg, top, left, img_size)
        # Horizontal flip
        if np.random.rand() > 0.5 and train:
            cropped_img = cropped_img[:, ::-1, :]
            cropped_seg = cropped_seg[:, ::-1, :]
        # Vertical flip
        if np.random.rand() > 0.5 and train:
            cropped_img = cropped_img[::-1, :, :]
            cropped_seg = cropped_seg[::-1, :, :]
        # Gaussian noise
        if train:
            noise = 0.001 * np.random.randn(*cropped_img.shape)
            cropped_img += noise
        imgs.append(cropped_img)
        segs.append(cropped_seg)
        if len(imgs) == batch_size:
            imgs_temp = np.array(imgs)
            segs_temp = np.array(segs)
            imgs = []
            segs = []
            if sample_weights is not None:
                yield (imgs_temp, segs_temp, sample_weights)
            else:
                yield (imgs_temp, segs_temp)
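
# Usage (sketch, assuming img91.png / seg91.png sit next to this script):
#   gen = get_datagen('img91.png', 'seg91.png', img_size=(256, 256), batch_size=4)
#   x, y = next(gen)
#   # x.shape == (4, 256, 256, 3), y.shape == (4, 256, 256, 2)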

def decode_img(x):
    # Non-mutating: the caller keeps its [0, 1] float array intact.
    return (x * 255).astype(np.uint8)

def decode_onehot(y):
    # Map a batch of two-channel one-hot masks to RGB (defects rendered in red).
    y = decode_img(y)
    zero_channel = np.zeros((*y.shape[:-1], 1), dtype=np.uint8)
    y = np.concatenate((y, zero_channel), axis=-1)
    y[:, :, :, 1] = 0  # suppress the background channel
    return y

def convert_prob_into_onehot(x):
    # Binarize softmax probabilities into a one-hot mask.
    t = tf.constant(value=x)
    y = tf.one_hot(tf.argmax(t, axis=-1), depth=2)
    return y.eval(session=K.get_session())
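
# A pure-NumPy equivalent (sketch; avoids needing a TF session):
#   def convert_prob_into_onehot_np(x):
#       return np.eye(x.shape[-1])[x.argmax(axis=-1)]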

def weighted_crossentropy_wrapper(class_weights):
    def weighted_cross_entropy(onehot_labels, output):
        '''
        Weighted cross entropy between one-hot labels and softmax outputs.

        class_weights (length num_classes) broadcasts across the last
        dimension of onehot_labels, so each pixel's log-likelihood is scaled
        by the weight of its true class before averaging.

        INPUTS:
        - onehot_labels (Tensor): one-hot labels of shape (batch_size, height, width, num_classes)
        - output (Tensor): softmax probabilities from the model, same shape
        - class_weights (list): class_weights[i] is the weight of class label i
        OUTPUTS:
        - loss (Tensor): scalar weighted cross-entropy loss
        '''
        # The epsilon sits inside the log to avoid log(0).
        loss = -tf.reduce_mean(onehot_labels * class_weights * tf.log(output + 1e-9))
        return loss
    return weighted_cross_entropy
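
# Toy check (sketch): one pixel of class 0 predicted with probability 0.9.
#   y_true = tf.constant([[[[1., 0.]]]])
#   y_pred = tf.constant([[[[0.9, 0.1]]]])
#   loss = weighted_crossentropy_wrapper([0.5, 0.5])(y_true, y_pred)
#   # K.get_session().run(loss) == -0.5 * log(0.9) / 2 ≈ 0.026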

class ImageWriter(Callback):
    """Writes a grid of input / ground truth / per-epoch predictions after every epoch."""
    def __init__(self, img_shape, batch_size):
        super().__init__()
        self.batch_size = batch_size
        test_gen = get_datagen('img91.png', 'seg91.png', train=False, batch_size=batch_size, img_size=img_shape)
        self.x, self.y = next(test_gen)
        self.y_shape = (self.x.shape[1], self.x.shape[2], 2)
        self.img = decode_img(self.x)
        self.gth = decode_onehot(self.y)
        self.preds = []

    def on_epoch_end(self, epoch, logs={}):
        self.p = self.model.predict_on_batch(self.x)
        self.pre = decode_onehot(self.p)
        self.preds.append(self.pre)
        figsize = (
            (self.x.shape[2] * (len(self.preds) + 1)) / 100,
            (self.x.shape[1] * (self.batch_size + 1)) / 100
        )
        fig, axes = plt.subplots(self.batch_size, 2 + len(self.preds), figsize=figsize)
        # Column titles: input, ground truth, then one column per epoch
        axes[0, 0].set_title('X')
        axes[0, 1].set_title('GT')
        for i in range(len(self.preds)):
            axes[0, i + 2].set_title(str(i))
        # Draw the images
        for i in range(self.batch_size):
            axes[i, 0].imshow(self.img[i], vmin=0, vmax=255)
            axes[i, 0].axis('off')
            axes[i, 1].imshow(self.gth[i], vmin=0, vmax=255)
            axes[i, 1].axis('off')
            for j in range(len(self.preds)):
                axes[i, j + 2].imshow(self.preds[j][i], vmin=0, vmax=255)
                axes[i, j + 2].axis('off')
        plt.savefig('history.jpg')
        plt.close(fig)  # free the figure; it grows by one column every epoch

if __name__ == '__main__':
    img_shape = (256, 256, 3)
    steps_per_epoch = 128
    validation_steps = 4
    epochs = 50
    batch_size = 16
    weight_decay_l2 = 0.01
    train_gen = get_datagen('img91.png', 'seg91.png', img_size=img_shape, batch_size=batch_size, sample_weights=None, train=True)
    test_gen = get_datagen('img91.png', 'seg91.png', img_size=img_shape, batch_size=batch_size, sample_weights=None, train=False)

    # Estimate class weights from a large sample batch: pixels per class divided
    # by the number of images containing that class, then inverted and normalized.
    _, y = next(get_datagen('img91.png', 'seg91.png', img_size=img_shape, batch_size=1024, sample_weights=None, train=False))
    pixcount = np.count_nonzero(y, axis=(0, 1, 2))
    imgcount = np.count_nonzero(np.count_nonzero(y, axis=(1, 2)), axis=0)
    freq = pixcount / imgcount
    weights = 1. / freq
    weights /= weights.sum()
    print('weights =', weights)

    # model = vgg19_unet(input_shape=img_shape, classes=2, weight_decay=weight_decay_l2)
    model = Pix2Pix(input_shape=img_shape, classes=2).build()
    # model = SegNet(input_shape=img_shape, classes=2)
    # model = UNetMatlab(input_shape=img_shape, classes=2).build()
    model.compile(
        # optimizer=optimizers.SGD(lr=5e-2, momentum=0.9, clipnorm=0.05),
        optimizer=optimizers.Adam(lr=1e-4, clipnorm=0.05),
        loss=weighted_crossentropy_wrapper(weights),
        metrics=['accuracy']
    )
    model.summary()
    mc_cb = ModelCheckpoint('model.h5', monitor='val_loss')
    im_cb = ImageWriter(img_shape, 32)
    history = model.fit_generator(
        generator=train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[mc_cb, im_cb],
        validation_data=test_gen,
        validation_steps=validation_steps,
        shuffle=True,
        use_multiprocessing=True
    )
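
    # Inspecting the training curves afterwards (sketch):
    #   plt.plot(history.history['loss'], label='train')
    #   plt.plot(history.history['val_loss'], label='val')
    #   plt.legend(); plt.savefig('loss.png')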

# ===== models.py =====
import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.engine import InputSpec
from keras import initializers, regularizers
from keras.layers import Input, Concatenate, BatchNormalization, Activation, MaxPooling2D, Dropout, Conv2DTranspose
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model

class UNetMatlab:
    """ https://jp.mathworks.com/help/images/multispectral-semantic-segmentation-using-deep-learning.html?lang=en """
    def __init__(self, input_shape, classes, l2reg=0.0001):
        self.input_shape = input_shape
        self.classes = classes
        self.l2reg = l2reg

    def build(self):
        x = Input(shape=self.input_shape)
        # Encoder
        h = Conv2D(64, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(x)
        h = Conv2D(64, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        d1 = h
        h = MaxPooling2D(2)(h)
        h = Conv2D(128, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Conv2D(128, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        d2 = h
        h = MaxPooling2D(2)(h)
        h = Conv2D(256, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Conv2D(256, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        d3 = h
        h = MaxPooling2D(2)(h)
        h = Conv2D(512, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Conv2D(512, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        d4 = h
        h = Dropout(0.5)(h)
        h = MaxPooling2D(2)(h)
        h = Conv2D(1024, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Conv2D(1024, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Dropout(0.5)(h)
        # Decoder (ReLU after every conv, matching the reference network)
        h = Conv2DTranspose(512, 2, strides=2, padding='valid', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Concatenate(axis=-1)([h, d4])
        h = Conv2D(512, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2D(512, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2DTranspose(256, 2, strides=2, padding='valid', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Concatenate(axis=-1)([h, d3])
        h = Conv2D(256, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2D(256, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2DTranspose(128, 2, strides=2, padding='valid', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Concatenate(axis=-1)([h, d2])
        h = Conv2D(128, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2D(128, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2DTranspose(64, 2, strides=2, padding='valid', kernel_initializer='he_normal', activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))(h)
        h = Concatenate(axis=-1)([h, d1])
        h = Conv2D(64, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        h = Conv2D(64, kernel_size=3, padding='same', kernel_initializer='he_normal', activation='relu')(h)
        logit = Conv2D(self.classes, kernel_size=1, padding='valid', kernel_initializer='he_normal')(h)
        prob = Activation('softmax')(logit)
        model = Model(x, prob)
        return model
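
# Usage (sketch): the encoder halves the spatial size four times, so the
# input height and width should be divisible by 16, e.g.:
#   model = UNetMatlab(input_shape=(256, 256, 3), classes=2).build()
#   model.summary()  # output shape: (None, 256, 256, 2)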
""" https://github.com/eriklindernoren/Keras-GAN/blob/master/pix2pix/pix2pix.py """
class Pix2Pix:
def __init__(self, input_shape, classes):
self.input_shape = input_shape
self.classes = classes
def build(self):
def conv(layer_input, filters):
"""Layers used during downsampling"""
d = ConvSN2D(filters, kernel_size=3, strides=1, dilation_rate=2, padding='same')(layer_input)
d = BatchNormalization(momentum=0.9)(d)
d = LeakyReLU(alpha=0.2)(d)
d = ConvSN2D(filters, kernel_size=3, strides=1, dilation_rate=2, padding='same')(d)
d = BatchNormalization(momentum=0.9)(d)
d = LeakyReLU(alpha=0.2)(d)
pooled = MaxPooling2D(2)(d)
return pooled, d
def deconv(layer_input, skip_input, filters):
"""Layers used during upsampling"""
u = UpSampling2D(size=2)(layer_input)
u = Concatenate(axis=-1)([u, skip_input])
u = ConvSN2D(filters, kernel_size=3, strides=1, padding='same')(u)
u = BatchNormalization(momentum=0.9)(u)
u = LeakyReLU(alpha=0.2)(u)
u = ConvSN2D(filters, kernel_size=3, strides=1, padding='same')(u)
u = BatchNormalization(momentum=0.9)(u)
u = LeakyReLU(alpha=0.2)(u)
return u
def res(layer_input, filters):
x = ConvSN2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
x = BatchNormalization(momentum=0.9)(x)
x = LeakyReLU(alpha=0.2)(x)
x = ConvSN2D(filters, kernel_size=3, strides=1, padding='same')(x)
x = BatchNormalization(momentum=0.9)(x)
x = LeakyReLU(alpha=0.2)(x)
return x
x = Input(shape=self.input_shape)
p1, d1 = conv(x, 64)
p2, d2 = conv(p1, 128)
p3, d3 = conv(p2, 256)
p4, d4 = conv(p3, 512)
p5, d5 = conv(p4, 512)
p6, d6 = conv(p5, 512)
p7, d7 = conv(p6, 1024)
z = res(p7, 1024)
u1 = deconv(z, d7, 512)
u2 = deconv(u1, d6, 512)
u3 = deconv(u2, d5, 512)
u4 = deconv(u3, d4, 256)
u5 = deconv(u4, d3, 128)
u6 = deconv(u5, d2, 64)
u7 = deconv(u6, d1, 64)
logit = ConvSN2D(self.classes, kernel_size=1)(u7)
prob = Activation('softmax')(logit)
return Model(x, prob)
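
# Usage (sketch): seven 2x downsamplings shrink 256 -> 2 at the bottleneck,
# so the input height and width must be divisible by 2**7 = 128, e.g.:
#   model = Pix2Pix(input_shape=(256, 256, 3), classes=2).build()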

def vgg19_unet(input_shape, weight_decay=0., classes=2):
    # Image input
    img = Input(shape=input_shape, name='image')
    # Block 1
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=regularizers.l2(weight_decay))(img)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=regularizers.l2(weight_decay))(conv1)
    conv1 = BatchNormalization()(conv1)
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(conv1)
    # Block 2
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', kernel_regularizer=regularizers.l2(weight_decay))(pool1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', kernel_regularizer=regularizers.l2(weight_decay))(conv2)
    conv2 = BatchNormalization()(conv2)
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(conv2)
    # Block 3
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', kernel_regularizer=regularizers.l2(weight_decay))(pool2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', kernel_regularizer=regularizers.l2(weight_decay))(conv3)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', kernel_regularizer=regularizers.l2(weight_decay))(conv3)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv4', kernel_regularizer=regularizers.l2(weight_decay))(conv3)
    conv3 = BatchNormalization()(conv3)
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(conv3)
    # Block 4
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', kernel_regularizer=regularizers.l2(weight_decay))(pool3)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', kernel_regularizer=regularizers.l2(weight_decay))(conv4)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', kernel_regularizer=regularizers.l2(weight_decay))(conv4)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv4', kernel_regularizer=regularizers.l2(weight_decay))(conv4)
    conv4 = BatchNormalization()(conv4)
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(conv4)
    # Block 5
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', kernel_regularizer=regularizers.l2(weight_decay))(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', kernel_regularizer=regularizers.l2(weight_decay))(conv5)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', kernel_regularizer=regularizers.l2(weight_decay))(conv5)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv4', kernel_regularizer=regularizers.l2(weight_decay))(conv5)
    conv5 = BatchNormalization()(conv5)
    # Decoder
    up6 = UpSampling2D(2)(conv5)
    up6 = Concatenate(axis=-1)([up6, conv4])
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block6_conv1')(up6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block6_conv2')(conv6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block6_conv3')(conv6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block6_conv4')(conv6)
    conv6 = BatchNormalization()(conv6)
    up7 = UpSampling2D(2)(conv6)
    up7 = Concatenate(axis=-1)([up7, conv3])
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block7_conv1')(up7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block7_conv2')(conv7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block7_conv3')(conv7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block7_conv4')(conv7)
    conv7 = BatchNormalization()(conv7)
    up8 = UpSampling2D(2)(conv7)
    up8 = Concatenate(axis=-1)([up8, conv2])
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block8_conv1')(up8)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block8_conv2')(conv8)
    conv8 = BatchNormalization()(conv8)
    up9 = UpSampling2D(2)(conv8)
    up9 = Concatenate(axis=-1)([up9, conv1])
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block9_conv1')(up9)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block9_conv2')(conv9)
    conv9 = BatchNormalization()(conv9)
    output = Conv2D(classes, (1, 1), padding='same', activation='softmax', name='prob')(conv9)
    model = Model(inputs=img, outputs=output)
    # Transfer ImageNet weights into the encoder: save the headless VGG19
    # weights to a temporary file, then load them into matching layer names.
    from keras.applications.vgg19 import VGG19
    import os
    weights_path = 'temp_vgg19_notop.h5'
    VGG19(input_shape=input_shape, weights='imagenet', include_top=False).save_weights(weights_path)
    model.load_weights(weights_path, by_name=True)
    os.remove(weights_path)
    return model
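
# Note on the weight transfer above: load_weights(by_name=True) only fills
# layers whose names match the VGG19 convention ('block1_conv1' ...
# 'block5_conv4'), i.e. the encoder; the decoder blocks keep their random
# initialization.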

def SegNet(input_shape=(360, 480, 3), classes=12):
    """ https://github.com/alexgkendall/SegNet-Tutorial/blob/master/Example_Models/bayesian_segnet_camvid.prototxt """
    img_input = Input(shape=input_shape)
    x = img_input
    # Encoder
    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    # Decoder
    x = Conv2D(512, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(classes, (1, 1), padding="valid")(x)
    x = Activation("softmax")(x)
    model = Model(img_input, x)
    return model
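
# Usage (sketch): three pooling / upsampling stages, so the input height and
# width should be divisible by 8, e.g.:
#   model = SegNet(input_shape=(360, 480, 3), classes=12)
#   model.summary()  # output shape: (None, 360, 480, 12)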
""" https://github.com/IShengFang/SpectralNormalizationKeras/blob/master/SpectralNormalizationKeras.py """
class ConvSN2D(Conv2D):
def build(self, input_shape):
if self.data_format == 'channels_first':
channel_axis = 1
else:
channel_axis = -1
if input_shape[channel_axis] is None:
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (input_dim, self.filters)
self.kernel = self.add_weight(shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
initializer=initializers.RandomNormal(0, 1),
name='sn',
trainable=False)
# Set input spec.
self.input_spec = InputSpec(ndim=self.rank + 2,
axes={channel_axis: input_dim})
self.built = True
def call(self, inputs, training=None):
def _l2normalize(v, eps=1e-12):
return v / (K.sum(v ** 2) ** 0.5 + eps)
def power_iteration(W, u):
#Accroding the paper, we only need to do power iteration one time.
_u = u
_v = _l2normalize(K.dot(_u, K.transpose(W)))
_u = _l2normalize(K.dot(_v, W))
return _u, _v
#Spectral Normalization
W_shape = self.kernel.shape.as_list()
#Flatten the Tensor
W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
_u, _v = power_iteration(W_reshaped, self.u)
#Calculate Sigma
sigma=K.dot(_v, W_reshaped)
sigma=K.dot(sigma, K.transpose(_u))
#normalize it
W_bar = W_reshaped / sigma
#reshape weight tensor
if training in {0, False}:
W_bar = K.reshape(W_bar, W_shape)
else:
with tf.control_dependencies([self.u.assign(_u)]):
W_bar = K.reshape(W_bar, W_shape)
outputs = K.conv2d(
inputs,
W_bar,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate)
if self.use_bias:
outputs = K.bias_add(
outputs,
self.bias,
data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
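
# Sanity check (sketch): the power-iteration estimate should approach the
# largest singular value of the flattened kernel, e.g. in NumPy:
#   W = np.random.randn(9 * 64, 128)
#   u = np.random.randn(1, 128)
#   for _ in range(50):
#       v = u @ W.T; v /= np.linalg.norm(v)
#       u = v @ W; u /= np.linalg.norm(u)
#   sigma = float(v @ W @ u.T)  # ≈ np.linalg.svd(W, compute_uv=False)[0]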