Created
June 24, 2021 06:44
-
-
Save Yuvnish017/044f128974a325d9cf9a4bc0df8e8f32 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def compute_loss(combination_image, base_image, style_reference_image):
    """Return the weighted sum of content, style and total-variation losses.

    The three images are concatenated along axis 0 in the order
    (base, style reference, combination) and sent through
    ``feature_extractor`` in a single call. The activations of each layer
    of interest are then picked out by position along that axis.

    NOTE(review): indices 0/1/2 only line up with the three images if each
    input contributes exactly one entry along axis 0 -- confirm against
    the caller.

    Args:
        combination_image: Image whose features are compared against both
            targets; also the only input of the total-variation term.
        base_image: Image supplying the content target.
        style_reference_image: Image supplying the style target.

    Returns:
        The accumulated loss tensor, built up from a scalar zero tensor.
    """
    # Positions of each image along axis 0 of the concatenated input.
    base_idx, style_idx, combo_idx = 0, 1, 2

    stacked_images = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    activations = feature_extractor(stacked_images)

    # Running total, starting from a scalar zero.
    total = tf.zeros(shape=())

    # Content term: base image vs. combination image at the content layer.
    content_activations = activations[content_layer_name]
    content_term = content_loss(
        content_activations[base_idx, :, :, :],
        content_activations[combo_idx, :, :, :],
    )
    total = total + content_weight * content_term

    # Style term: style reference vs. combination image at every style
    # layer, each layer contributing an equal share of style_weight.
    for name in style_layer_names:
        style_activations = activations[name]
        layer_style_term = style_loss(
            style_activations[style_idx, :, :, :],
            style_activations[combo_idx, :, :, :],
        )
        total = total + (style_weight / len(style_layer_names)) * layer_style_term

    # Total-variation term, computed on the combination image alone.
    tv_term = total_variation_loss(combination_image)
    return total + total_variation_weight * tv_term
@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    """Evaluate ``compute_loss`` and its gradient w.r.t. the combination image.

    NOTE(review): ``combination_image`` is never passed to ``tape.watch``,
    so it is presumably a ``tf.Variable`` that the tape tracks
    automatically -- verify against the caller.

    Args:
        combination_image: Tensor the gradient is taken with respect to.
        base_image: Forwarded unchanged to ``compute_loss``.
        style_reference_image: Forwarded unchanged to ``compute_loss``.

    Returns:
        A ``(loss, grads)`` tuple, where ``grads`` is the result of
        ``tape.gradient(loss, combination_image)``.
    """
    # Record the forward pass so it can be differentiated afterwards.
    with tf.GradientTape() as grad_tape:
        total_loss = compute_loss(
            combination_image, base_image, style_reference_image
        )

    # The gradient is requested after the recording context has closed.
    gradients = grad_tape.gradient(total_loss, combination_image)
    return total_loss, gradients
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment