@wcneill
Last active July 12, 2020 20:12
Putting it all together
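import torch
from torch import optim
import matplotlib.pyplot as plt

# load_image, get_features, gramian, content_loss, style_loss, im_convert,
# vgg and device are assumed to be defined elsewhere (not part of this gist).

# load in content and style image, and create target image by copying content image
content = load_image('data/style/clouds-19.jpg').to(device)
style = load_image('data/style/abstract-art-freedom.jpg', shape=content.shape[-2:]).to(device)
# content is already on device; clone it and track gradients on the copy's pixels
target = content.clone().requires_grad_(True)

# per-layer weights for the style loss
style_weights = {'conv1_1': .2,
                 'conv2_1': .2,
                 'conv3_1': .2,
                 'conv4_1': .2,
                 'conv5_1': .2}

show = 1000   # display the target image every `show` steps
steps = 5000  # total number of optimization steps
alpha = 1     # content loss weight
beta = 1e3    # style loss weight

# get network layer output for style and content templates.
s_features = get_features(vgg, style)
c_features = get_features(vgg, content)

# compute Gram matrices from style template features
s_grams = {layer: gramian(features) for layer, features in s_features.items()}

# set optimizer to update target image pixels
opt = optim.Adam([target], lr=0.001)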
# perform backpropagation on the target image pixels
for step in range(1, steps + 1):
    opt.zero_grad()

    # get target image feature channels
    t_features = get_features(vgg, target)

    # get content and style loss
    c_loss = content_loss(c_features, t_features)
    s_loss = style_loss(s_grams, t_features, style_weights)

    # compute weighted total loss using alpha (content) and beta (style)
    total_loss = alpha * c_loss + beta * s_loss

    # backpropagate and take a gradient-descent step on the target image pixels
    total_loss.backward()
    opt.step()

    # occasionally show the updated target image
    if step % show == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
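The loop above calls `gramian`, `content_loss` and `style_loss`, but this gist does not define them. Below is a minimal PyTorch sketch of what they might look like. It assumes each feature map has shape (1, C, H, W) and that `get_features` returns a dict keyed by layer name. Two choices here are assumptions that the gist does not confirm: using `conv4_2` as the content layer, and dividing each style term by C·H·W.

```python
import torch


def gramian(features: torch.Tensor) -> torch.Tensor:
    """Gram matrix of a (1, C, H, W) feature map: inner products between channels."""
    _, c, h, w = features.shape
    flat = features.reshape(c, h * w)  # one row per channel (batch size 1 assumed)
    return flat @ flat.t()             # shape (C, C)


def content_loss(c_features, t_features, layer="conv4_2"):
    """Mean squared difference between content and target features at one layer.

    The default layer name is an assumption, not taken from the gist.
    """
    return torch.mean((t_features[layer] - c_features[layer]) ** 2)


def style_loss(s_grams, t_features, style_weights):
    """Weighted sum over layers of the mean squared Gram-matrix difference.

    Dividing each layer's term by C*H*W is an assumed normalization.
    """
    loss = 0.0
    for layer, weight in style_weights.items():
        t_feat = t_features[layer]
        _, c, h, w = t_feat.shape
        t_gram = gramian(t_feat)
        layer_loss = weight * torch.mean((t_gram - s_grams[layer]) ** 2)
        loss = loss + layer_loss / (c * h * w)
    return loss


if __name__ == "__main__":
    # Toy check with random "feature maps" standing in for VGG outputs.
    torch.manual_seed(0)
    feats = {"conv1_1": torch.rand(1, 3, 8, 8), "conv4_2": torch.rand(1, 4, 4, 4)}
    target = {k: (v + 0.1 * torch.randn_like(v)).requires_grad_(True) for k, v in feats.items()}
    grams = {k: gramian(v) for k, v in feats.items()}
    total = 1 * content_loss(feats, target) + 1e3 * style_loss(grams, target, {"conv1_1": 1.0})
    total.backward()  # gradients flow back to the target features
    print(total.item())
```

The Gram matrix discards spatial layout and keeps only the correlations between channels. Matching it therefore transfers texture and color statistics, not the arrangement of objects. This is why the content term is still needed to keep the target recognizable.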