Last active: July 12, 2020 20:12
Save wcneill/c706210d8b89f91e5fef825d94187bee to your computer and use it in GitHub Desktop.
Putting it all together
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Neural style transfer: optimise the pixels of `target` so that its network
# features match the content image (content loss) and its Gram matrices match
# the style image (style loss), then periodically display the result.
# NOTE(review): load_image, get_features, gramian, content_loss, style_loss,
# im_convert, vgg, device, optim and plt are defined/imported elsewhere in the
# original notebook/gist — confirm they are in scope before running.

# load in content and style image. The style image is loaded with
# shape=content.shape[-2:] — presumably so it is resized to the content
# image's spatial size (assumes (..., H, W) tensors; TODO confirm in load_image).
content = load_image('data/style/clouds-19.jpg').to(device)
style = load_image('data/style/abstract-art-freedom.jpg', shape=content.shape[-2:]).to(device)

# Create the target image by copying the content image. `content` is already on
# `device`, so the clone is too; enabling grad *last* keeps `target` a leaf
# tensor, which torch.optim requires (calling .to() after requires_grad_ would
# return a non-leaf copy whenever it actually moves the tensor).
target = content.clone().requires_grad_(True)

# Per-layer weights for the style loss (equal weighting across the five layers).
style_weights = {'conv1_1': .2,
                 'conv2_1': .2,
                 'conv3_1': .2,
                 'conv4_1': .2,
                 'conv5_1': .2}

show = 1000   # print/display progress every `show` steps
steps = 5000  # total number of optimisation steps
alpha = 1     # weight on the content loss
beta = 1e3    # weight on the style loss

# get network layer output for style and content templates.
s_features = get_features(vgg, style)
c_features = get_features(vgg, content)

# compute Gramian matrices from style template features
s_grams = {layer: gramian(features) for layer, features in s_features.items()}

# set optimizer to update target image pixels only
opt = optim.Adam([target], lr=0.001)

# Perform back propagation on target image pixels
for step in range(1, steps + 1):
    opt.zero_grad()

    # get target image feature channels
    t_features = get_features(vgg, target)

    # get content and style loss
    c_loss = content_loss(c_features, t_features)
    s_loss = style_loss(s_grams, t_features, style_weights)

    # Weighted total loss. The original referenced undefined names
    # c_weight / s_weight; alpha and beta are the weights defined above.
    total_loss = alpha * c_loss + beta * s_loss

    # compute gradients and update target image pixels.
    total_loss.backward()
    opt.step()

    # occasionally show updated target image.
    if step % show == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.