Skip to content

Instantly share code, notes, and snippets.

@ArnoutDevos
Created January 8, 2018 11:41
Show Gist options
  • Save ArnoutDevos/0cb0328aa09633d0abb057de7362234d to your computer and use it in GitHub Desktop.
# (layer_name, weight) pairs consumed by style_loss_func; every listed layer
# contributes with the same weight of 0.5. The names look like the first conv
# layer of each of five VGG-style blocks -- TODO confirm against the model.
STYLE_LAYERS = [
    (layer_name, 0.5)
    for layer_name in ("conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1")
]
def style_loss_func(sess, model, style_layers=None):
    """Build the weighted Gram-matrix style loss over several layers.

    The activations that ``model`` produces *right now* are evaluated once with
    ``sess`` and frozen as constant targets. Each target is compared with the
    symbolic (live) tensor of the same layer through their Gram matrices.

    Args:
        sess: TensorFlow session used to evaluate the target activations.
            NOTE(review): presumably the model input holds the style image
            when this is called -- confirm against the caller.
        model: mapping from layer name to a feature tensor. Axes 1 and 2 are
            treated as spatial and axis 3 as channels (channels-last layout
            assumed -- TODO confirm).
        style_layers: optional sequence of ``(layer_name, weight)`` pairs.
            Defaults to the module-level ``STYLE_LAYERS``, read at call time.

    Returns:
        A scalar tensor ``sum(weight * layer_loss)`` over the layers, or the
        Python int ``0`` when no layers are given.
    """

    def _gram_matrix(feat):
        """Return the (C, C) Gram matrix of ``feat``, with C = size of axis 3."""
        num_channels = int(feat.get_shape()[3])
        # Collapse every non-channel axis (batch included) into rows.
        matrix = tf.reshape(feat, shape=[-1, num_channels])
        return tf.matmul(matrix, matrix, transpose_a=True)

    def _layer_style_loss(target_value, current_feat):
        """Loss between a fixed evaluated array and a live feature tensor."""
        # Sizes are read from the evaluated array, so they are plain Python ints.
        height = target_value.shape[1]
        width = target_value.shape[2]
        num_channels = target_value.shape[3]
        num_positions = height * width
        gram_target = _gram_matrix(tf.convert_to_tensor(target_value))
        gram_current = _gram_matrix(current_feat)
        # Float arithmetic: with ``1 / ...`` Python 2 integer division would
        # truncate the factor to 0 and silently disable the style loss.
        scale = 1.0 / (4.0 * num_channels ** 2 * num_positions ** 2)
        return scale * tf.reduce_sum(tf.square(gram_target - gram_current))

    if style_layers is None:
        style_layers = STYLE_LAYERS
    layers = list(style_layers)
    if not layers:
        return 0

    current_feats = [model[layer_name] for layer_name, _ in layers]
    # One session call evaluates every target layer in a single forward pass
    # instead of one pass per layer.
    target_values = sess.run(current_feats)

    return sum(
        weight * _layer_style_loss(target_value, current_feat)
        for (_, weight), target_value, current_feat in zip(
            layers, target_values, current_feats
        )
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment