Skip to content

Instantly share code, notes, and snippets.

@neelriyer
Created August 10, 2020 00:26
Show Gist options
  • Save neelriyer/a4093fc9dac010f6fb7a36a9e43df7ea to your computer and use it in GitHub Desktop.
Save neelriyer/a4093fc9dac010f6fb7a36a9e43df7ea to your computer and use it in GitHub Desktop.
Run neural style transfer for audio
import torch
import torch.nn as nn
from torch.nn import Conv2d, ReLU, AvgPool1d, MaxPool2d, Linear, Conv1d
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import os
import torchvision.transforms as transforms
import gc; gc.collect()
# Base learning rate handed to the optimizer that updates the input tensor.
learning_rate_initial = 1e-4

# The optimisation starts from a clone of the content tensor.
# Alternative starting point (random noise of the same size), kept for reference:
# input_float = Variable(torch.randn(content_float.size())).type(torch.FloatTensor)
input_float = content_float.clone()
def get_input_param_optimizer(input_float, lr=None):
    """Wrap the input tensor as a trainable parameter and build its optimizer.

    Args:
        input_float: Tensor whose ``.data`` becomes the optimised parameter.
        lr: Learning rate for the Adam optimizer. When ``None`` (the
            default) the module-level ``learning_rate_initial`` is used,
            which is the behaviour callers had before this parameter existed.

    Returns:
        A ``(input_param, optimizer)`` tuple: the ``nn.Parameter`` built from
        ``input_float`` and an ``optim.Adam`` instance that updates only that
        parameter.
    """
    if lr is None:
        # Resolved at call time so later changes to the global are honoured.
        lr = learning_rate_initial

    input_param = nn.Parameter(input_float.data)

    # Alternative optimizers left in place by the original author:
    # optimizer = optim.Adagrad([input_param], lr=lr, lr_decay=0.0001, weight_decay=0)
    # optimizer = optim.SGD([input_param], lr=lr)
    # optimizer = optim.RMSprop([input_param], lr=lr)
    optimizer = optim.Adam([input_param], lr=lr)
    return input_param, optimizer
# Default step budget used by run_style_transfer.
num_steps = 10000
# from https://pytorch.org/tutorials/advanced/neural_style_tutorial.html
def run_style_transfer(cnn, style_float=style_float,
                       content_float=content_float,
                       input_float=input_float,
                       num_steps=num_steps, style_weight=style_weight):
    """Optimise ``input_float`` so it matches the style and content targets.

    Builds the loss-instrumented model with ``get_style_model_and_losses``,
    wraps ``input_float`` as a parameter via ``get_input_param_optimizer``
    and repeatedly steps the optimizer on the weighted style + content loss.

    Args:
        cnn: Network passed through to ``get_style_model_and_losses``.
        style_float: Style target tensor.
        content_float: Content target tensor.
        input_float: Starting tensor for the optimisation.
        num_steps: Step budget. The loop continues while the closure counter
            is ``<= num_steps``, so the closure is evaluated at least
            ``num_steps + 1`` times (kept as in the original).
        style_weight: Multiplier applied to the summed style losses.

    Returns:
        The optimised tensor (``input_param.data``), clamped to ``[0, 1]``.

    Note:
        The content-loss multiplier is read from the module-level
        ``content_weight``; it is not a parameter of this function.
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, style_float, content_float)
    input_param, optimizer = get_input_param_optimizer(input_float)

    print('Optimizing..')
    step = 0  # number of closure evaluations so far

    def closure():
        """Evaluate the weighted loss once and back-propagate it."""
        nonlocal step

        # Correct the values of the updated input before the forward pass.
        input_param.data.clamp_(0, 1)

        optimizer.zero_grad()
        # The output is discarded; presumably the forward pass refreshes the
        # `.loss` attribute on every loss module read below.
        model(input_param)

        style_score = sum(sl.loss for sl in style_losses) * style_weight
        content_score = sum(cl.loss for cl in content_losses) * content_weight

        loss = style_score + content_score
        loss.backward()

        step += 1
        if step % 100 == 0:
            # Report the step number itself (the original formatted the
            # one-element counter list, printing e.g. "run [100]:").
            print("run {}:".format(step))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                style_score.item(), content_score.item()))
            print()

        return loss

    # The closure is defined once above rather than rebuilt every iteration.
    while step <= num_steps:
        optimizer.step(closure)

    # Ensure values are between 0 and 1 after the final update.
    input_param.data.clamp_(0, 1)
    return input_param.data
# Run the optimisation on the module-level tensors and keep the result.
output = run_style_transfer(
    cnn,
    style_float=style_float,
    content_float=content_float,
    input_float=input_float,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment