Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Last active September 28, 2019 17:41
Show Gist options
  • Save ProGamerGov/de7a8734e05011018d535385de31b034 to your computer and use it in GitHub Desktop.
Save ProGamerGov/de7a8734e05011018d535385de31b034 to your computer and use it in GitHub Desktop.

Test 1 (multi-device run: -gpu 0,c, -multidevice_strategy 0, -image_size 256):

ubuntu@ip-Address:~/neural-style-pt$ python3 neural_style.py -num_iterations 500 -gpu 0,c -backend cudnn -seed 876 -multidevice_strategy 0 -image_size 256

File: neural_style.py
Function: main at line 56

Line # Max usage   Peak usage diff max diff peak  Line Contents
===============================================================
    56                                           @profile
    57                                           def main():
    58     0.00B        0.00B -144.11M -186.00M      dtype, multidevice, backward_device = setup_gpu()
    59
    60    76.39M      574.00M   76.39M  574.00M      cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, False)
    61
    62    76.95M      102.00M  576.00K -472.00M      content_image = preprocess(params.content_image, params.image_size).type(dtype)
    63    76.95M      102.00M    0.00B    0.00B      style_image_input = params.style_image.split(',')
    64    76.95M      102.00M    0.00B    0.00B      style_image_list, ext = [], [".jpg",".png"]
    65    76.95M      102.00M    0.00B    0.00B      for image in style_image_input:
    66    76.95M      102.00M    0.00B    0.00B          if os.path.isdir(image):
    67                                                       images = (image + "/" + file for file in os.listdir(image)
    68                                                       if os.path.splitext(file)[1].lower() in ext)
    69                                                       style_image_list.extend(images)
    70                                                   else:
    71    76.95M      102.00M    0.00B    0.00B              style_image_list.append(image)
    72    76.95M      102.00M    0.00B    0.00B      style_images_caffe = []
    73    77.52M      104.00M  588.00K    2.00M      for image in style_image_list:
    74    76.95M      102.00M -588.00K   -2.00M          style_size = int(params.image_size * params.style_scale)
    75    77.52M      104.00M  588.00K    2.00M          img_caffe = preprocess(image, style_size).type(dtype)
    76    77.52M      104.00M    0.00B    0.00B          style_images_caffe.append(img_caffe)
    77
    78    77.52M      104.00M    0.00B    0.00B      if params.init_image != None:
    79                                                   image_size = (content_image.size(2), content_image.size(3))
    80                                                   init_image = preprocess(params.init_image, image_size).type(dtype)
    81
    82                                               # Handle style blending weights for multiple style inputs
    83    77.52M      104.00M    0.00B    0.00B      style_blend_weights = []
    84    77.52M      104.00M    0.00B    0.00B      if params.style_blend_weights == None:
    85                                                   # Style blending not specified, so use equal weighting
    86    77.52M      104.00M    0.00B    0.00B          for i in style_image_list:
    87    77.52M      104.00M    0.00B    0.00B              style_blend_weights.append(1.0)
    88    77.52M      104.00M    0.00B    0.00B          for i, blend_weights in enumerate(style_blend_weights):
    89    77.52M      104.00M    0.00B    0.00B              style_blend_weights[i] = int(style_blend_weights[i])
    90                                               else:
    91                                                   style_blend_weights = params.style_blend_weights.split(',')
    92                                                   assert len(style_blend_weights) == len(style_image_list), \
    93                                                     "-style_blend_weights and -style_images must have the same number of elements!"
    94
    95                                               # Normalize the style blending weights so they sum to 1
    96    77.52M      104.00M    0.00B    0.00B      style_blend_sum = 0
    97    77.52M      104.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
    98    77.52M      104.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i])
    99    77.52M      104.00M    0.00B    0.00B          style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
   100    77.52M      104.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
   101    77.52M      104.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
   102
   103    77.52M      104.00M    0.00B    0.00B      content_layers = params.content_layers.split(',')
   104    77.52M      104.00M    0.00B    0.00B      style_layers = params.style_layers.split(',')
   105
   106                                               # Set up the network, inserting style and content loss modules
   107    79.27M      184.00M    1.75M   80.00M      cnn = copy.deepcopy(cnn)
   108    79.27M      144.00M    0.00B  -40.00M      content_losses, style_losses, tv_losses = [], [], []
   109    79.27M      144.00M    0.00B    0.00B      next_content_idx, next_style_idx = 1, 1
   110    79.27M      144.00M    0.00B    0.00B      net = nn.Sequential()
   111    79.27M      144.00M    0.00B    0.00B      c, r = 0, 0
   112    79.27M      144.00M    0.00B    0.00B      if params.tv_weight > 0:
   113    79.27M      144.00M    0.00B    0.00B          tv_mod = TVLoss(params.tv_weight).type(dtype)
   114    79.27M      144.00M    0.00B    0.00B          net.add_module(str(len(net)), tv_mod)
   115    79.27M      144.00M    0.00B    0.00B          tv_losses.append(tv_mod)
   116
   117    79.27M      144.00M    0.00B    0.00B      for i, layer in enumerate(list(cnn), 1):
   118    79.27M      144.00M    0.00B    0.00B          if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
   119    79.27M      144.00M    0.00B    0.00B              if isinstance(layer, nn.Conv2d):
   120    79.27M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   121
   122    79.27M      144.00M    0.00B    0.00B                  if layerList['C'][c] in content_layers:
   123                                                               print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
   124                                                               loss_module = ContentLoss(params.content_weight)
   125                                                               net.add_module(str(len(net)), loss_module)
   126                                                               content_losses.append(loss_module)
   127
   128    79.27M      144.00M    0.00B    0.00B                  if layerList['C'][c] in style_layers:
   129                                                               print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
   130                                                               loss_module = StyleLoss(params.style_weight)
   131                                                               net.add_module(str(len(net)), loss_module)
   132                                                               style_losses.append(loss_module)
   133    79.27M      144.00M    0.00B    0.00B                  c+=1
   134
   135    79.27M      144.00M    0.00B    0.00B              if isinstance(layer, nn.ReLU):
   136    79.27M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   137
   138    79.27M      144.00M    0.00B    0.00B                  if layerList['R'][r] in content_layers:
   139    79.27M      144.00M    0.00B    0.00B                      print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
   140    79.27M      144.00M    0.00B    0.00B                      loss_module = ContentLoss(params.content_weight)
   141    79.27M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   142    79.27M      144.00M    0.00B    0.00B                      content_losses.append(loss_module)
   143    79.27M      144.00M    0.00B    0.00B                      next_content_idx += 1
   144
   145    79.27M      144.00M    0.00B    0.00B                  if layerList['R'][r] in style_layers:
   146    79.27M      144.00M    0.00B    0.00B                      print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
   147    79.27M      144.00M    0.00B    0.00B                      loss_module = StyleLoss(params.style_weight)
   148    79.27M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   149    79.27M      144.00M    0.00B    0.00B                      style_losses.append(loss_module)
   150    79.27M      144.00M    0.00B    0.00B                      next_style_idx += 1
   151    79.27M      144.00M    0.00B    0.00B                  r+=1
   152
   153    79.27M      144.00M    0.00B    0.00B              if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
   154    79.27M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   155
   156    79.27M      144.00M    0.00B    0.00B      if multidevice:
   157    28.14M      144.00M  -51.13M    0.00B          net = setup_multi_device(net)
   158
   159                                               # Capture content targets
   160    28.14M       44.00M    0.00B -100.00M      for i in content_losses:
   161    28.14M       44.00M    0.00B    0.00B          i.mode = 'capture'
   162    28.14M       44.00M    0.00B    0.00B      print("Capturing content targets")
   163    28.14M       44.00M    0.00B    0.00B      print_torch(net, multidevice)
   164    29.26M       44.00M    1.12M    0.00B      net(content_image)
   165
   166                                               # Capture style targets
   167    29.26M       44.00M    0.00B    0.00B      for i in content_losses:
   168    29.26M       44.00M    0.00B    0.00B          i.mode = 'None'
   169
   170    29.29M       44.00M   24.00K    0.00B      for i, image in enumerate(style_images_caffe):
   171    29.26M       44.00M  -24.00K    0.00B          print("Capturing style target " + str(i+1))
   172    29.26M       44.00M    0.00B    0.00B          for j in style_losses:
   173    29.26M       44.00M    0.00B    0.00B              j.mode = 'capture'
   174    29.26M       44.00M    0.00B    0.00B              j.blend_weight = style_blend_weights[i]
   175    29.29M       46.00M   24.00K    2.00M          net(style_images_caffe[i])
   176
   177                                               # Set all loss modules to loss mode
   178    29.29M       44.00M    0.00B   -2.00M      for i in content_losses:
   179    29.29M       44.00M    0.00B    0.00B          i.mode = 'loss'
   180    29.29M       44.00M    0.00B    0.00B      for i in style_losses:
   181    29.29M       44.00M    0.00B    0.00B          i.mode = 'loss'
   182
   183                                               # Freeze the network in order to prevent
   184                                               # unnecessary gradient calculations
   185    29.29M       44.00M    0.00B    0.00B      for param in net.parameters():
   186                                                   param.requires_grad = False
   187
   188                                               # Initialize the image
   189    29.29M       44.00M    0.00B    0.00B      if params.seed >= 0:
   190    29.29M       44.00M    0.00B    0.00B          torch.manual_seed(params.seed)
   191    29.29M       44.00M    0.00B    0.00B          torch.cuda.manual_seed(params.seed)
   192    29.29M       44.00M    0.00B    0.00B          torch.cuda.manual_seed_all(params.seed)
   193    29.29M       44.00M    0.00B    0.00B          torch.backends.cudnn.deterministic=True
   194    29.29M       44.00M    0.00B    0.00B      if params.init == 'random':
   195    29.29M       44.00M    0.00B    0.00B          B, C, H, W = content_image.size()
   196    29.85M       46.00M  576.00K    2.00M          img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
   197                                               elif params.init == 'image':
   198                                                   if params.init_image != None:
   199                                                       img = init_image.clone()
   200                                                   else:
   201                                                       img = content_image.clone()
   202    29.85M       46.00M    0.00B    0.00B      img = nn.Parameter(img.type(dtype))
   203
   204    29.85M       46.00M    0.00B    0.00B      def maybe_print(t, loss):
   205                                                   if params.print_iter > 0 and t % params.print_iter == 0:
   206                                                       print("Iteration " + str(t) + " / "+ str(params.num_iterations))
   207                                                       for i, loss_module in enumerate(content_losses):
   208                                                           print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   209                                                       for i, loss_module in enumerate(style_losses):
   210                                                           print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   211                                                       print("  Total loss: " + str(loss.item()))
   212
   213    29.85M       46.00M    0.00B    0.00B      def maybe_save(t):
   214                                                   should_save = params.save_iter > 0 and t % params.save_iter == 0
   215                                                   should_save = should_save or t == params.num_iterations
   216                                                   if should_save:
   217                                                       output_filename, file_extension = os.path.splitext(params.output_image)
   218                                                       if t == params.num_iterations:
   219                                                           filename = output_filename + str(file_extension)
   220                                                       else:
   221                                                           filename = str(output_filename) + "_" + str(t) + str(file_extension)
   222                                                       disp = deprocess(img.clone())
   223
   224                                                       # Maybe perform postprocessing for color-independent style transfer
   225                                                       if params.original_colors == 1:
   226                                                           disp = original_colors(deprocess(content_image.clone()), disp)
   227
   228                                                       disp.save(str(filename))
   229
   230                                               # Function to evaluate loss and gradient. We run the net forward and
   231                                               # backward to get the gradient, and sum up losses from the loss modules.
   232                                               # optim.lbfgs internally handles iteration and calls this function many
   233                                               # times, so we manually count the number of iterations to handle printing
   234                                               # and saving intermediate results.
   235    29.85M       46.00M    0.00B    0.00B      num_calls = [0]
   236    29.85M       46.00M    0.00B    0.00B      def feval():
   237                                                   num_calls[0] += 1
   238                                                   optimizer.zero_grad()
   239                                                   net(img)
   240                                                   loss = 0
   241
   242                                                   for mod in content_losses:
   243                                                       loss += mod.loss.to(backward_device)
   244                                                   for mod in style_losses:
   245                                                       loss += mod.loss.to(backward_device)
   246                                                   if params.tv_weight > 0:
   247                                                       for mod in tv_losses:
   248                                                           loss += mod.loss.to(backward_device)
   249
   250                                                   loss.backward()
   251
   252                                                   maybe_save(num_calls[0])
   253                                                   maybe_print(num_calls[0], loss)
   254
   255                                                   return loss
   256
   257    29.85M       46.00M   29.85M   46.00M      optimizer, loopVal = setup_optimizer(img)
   258   144.11M      184.00M  114.26M  138.00M      while num_calls[0] <= loopVal:
   259   144.11M      186.00M    0.00B    2.00M           optimizer.step(feval)

GPU memory reported by nvidia-smi for Test 1: 519 MiB

Test 2 (single-GPU run: -gpu 0, -image_size 256, no multidevice strategy):

ubuntu@ip-Address:~/neural-style-pt$ python3 neural_style.py -num_iterations 500 -gpu 0 -backend cudnn -seed 876 -image_size 256

File: neural_style.py
Function: main at line 56

Line # Max usage   Peak usage diff max diff peak  Line Contents
===============================================================
    56                                           @profile
    57                                           def main():
    58     0.00B        0.00B -201.90M -376.00M      dtype, multidevice, backward_device = setup_gpu()
    59
    60    76.39M      574.00M   76.39M  574.00M      cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, False)
    61
    62    76.95M      102.00M  576.00K -472.00M      content_image = preprocess(params.content_image, params.image_size).type(dtype)
    63    76.95M      102.00M    0.00B    0.00B      style_image_input = params.style_image.split(',')
    64    76.95M      102.00M    0.00B    0.00B      style_image_list, ext = [], [".jpg",".png"]
    65    76.95M      102.00M    0.00B    0.00B      for image in style_image_input:
    66    76.95M      102.00M    0.00B    0.00B          if os.path.isdir(image):
    67                                                       images = (image + "/" + file for file in os.listdir(image)
    68                                                       if os.path.splitext(file)[1].lower() in ext)
    69                                                       style_image_list.extend(images)
    70                                                   else:
    71    76.95M      102.00M    0.00B    0.00B              style_image_list.append(image)
    72    76.95M      102.00M    0.00B    0.00B      style_images_caffe = []
    73    77.52M      104.00M  588.00K    2.00M      for image in style_image_list:
    74    76.95M      102.00M -588.00K   -2.00M          style_size = int(params.image_size * params.style_scale)
    75    77.52M      104.00M  588.00K    2.00M          img_caffe = preprocess(image, style_size).type(dtype)
    76    77.52M      104.00M    0.00B    0.00B          style_images_caffe.append(img_caffe)
    77
    78    77.52M      104.00M    0.00B    0.00B      if params.init_image != None:
    79                                                   image_size = (content_image.size(2), content_image.size(3))
    80                                                   init_image = preprocess(params.init_image, image_size).type(dtype)
    81
    82                                               # Handle style blending weights for multiple style inputs
    83    77.52M      104.00M    0.00B    0.00B      style_blend_weights = []
    84    77.52M      104.00M    0.00B    0.00B      if params.style_blend_weights == None:
    85                                                   # Style blending not specified, so use equal weighting
    86    77.52M      104.00M    0.00B    0.00B          for i in style_image_list:
    87    77.52M      104.00M    0.00B    0.00B              style_blend_weights.append(1.0)
    88    77.52M      104.00M    0.00B    0.00B          for i, blend_weights in enumerate(style_blend_weights):
    89    77.52M      104.00M    0.00B    0.00B              style_blend_weights[i] = int(style_blend_weights[i])
    90                                               else:
    91                                                   style_blend_weights = params.style_blend_weights.split(',')
    92                                                   assert len(style_blend_weights) == len(style_image_list), \
    93                                                     "-style_blend_weights and -style_images must have the same number of elements!"
    94
    95                                               # Normalize the style blending weights so they sum to 1
    96    77.52M      104.00M    0.00B    0.00B      style_blend_sum = 0
    97    77.52M      104.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
    98    77.52M      104.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i])
    99    77.52M      104.00M    0.00B    0.00B          style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
   100    77.52M      104.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
   101    77.52M      104.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
   102
   103    77.52M      104.00M    0.00B    0.00B      content_layers = params.content_layers.split(',')
   104    77.52M      104.00M    0.00B    0.00B      style_layers = params.style_layers.split(',')
   105
   106                                               # Set up the network, inserting style and content loss modules
   107    79.27M      184.00M    1.75M   80.00M      cnn = copy.deepcopy(cnn)
   108    79.27M      144.00M    0.00B  -40.00M      content_losses, style_losses, tv_losses = [], [], []
   109    79.27M      144.00M    0.00B    0.00B      next_content_idx, next_style_idx = 1, 1
   110    79.27M      144.00M    0.00B    0.00B      net = nn.Sequential()
   111    79.27M      144.00M    0.00B    0.00B      c, r = 0, 0
   112    79.27M      144.00M    0.00B    0.00B      if params.tv_weight > 0:
   113    79.27M      144.00M    0.00B    0.00B          tv_mod = TVLoss(params.tv_weight).type(dtype)
   114    79.27M      144.00M    0.00B    0.00B          net.add_module(str(len(net)), tv_mod)
   115    79.27M      144.00M    0.00B    0.00B          tv_losses.append(tv_mod)
   116
   117    79.27M      144.00M    0.00B    0.00B      for i, layer in enumerate(list(cnn), 1):
   118    79.27M      144.00M    0.00B    0.00B          if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
   119    79.27M      144.00M    0.00B    0.00B              if isinstance(layer, nn.Conv2d):
   120    79.27M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   121
   122    79.27M      144.00M    0.00B    0.00B                  if layerList['C'][c] in content_layers:
   123                                                               print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
   124                                                               loss_module = ContentLoss(params.content_weight)
   125                                                               net.add_module(str(len(net)), loss_module)
   126                                                               content_losses.append(loss_module)
   127
   128    79.27M      144.00M    0.00B    0.00B                  if layerList['C'][c] in style_layers:
   129                                                               print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
   130                                                               loss_module = StyleLoss(params.style_weight)
   131                                                               net.add_module(str(len(net)), loss_module)
   132                                                               style_losses.append(loss_module)
   133    79.27M      144.00M    0.00B    0.00B                  c+=1
   134
   135    79.27M      144.00M    0.00B    0.00B              if isinstance(layer, nn.ReLU):
   136    79.27M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   137
   138    79.27M      144.00M    0.00B    0.00B                  if layerList['R'][r] in content_layers:
   139    79.27M      144.00M    0.00B    0.00B                      print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
   140    79.27M      144.00M    0.00B    0.00B                      loss_module = ContentLoss(params.content_weight)
   141    79.27M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   142    79.27M      144.00M    0.00B    0.00B                      content_losses.append(loss_module)
   143    79.27M      144.00M    0.00B    0.00B                      next_content_idx += 1
   144
   145    79.27M      144.00M    0.00B    0.00B                  if layerList['R'][r] in style_layers:
   146    79.27M      144.00M    0.00B    0.00B                      print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
   147    79.27M      144.00M    0.00B    0.00B                      loss_module = StyleLoss(params.style_weight)
   148    79.27M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   149    79.27M      144.00M    0.00B    0.00B                      style_losses.append(loss_module)
   150    79.27M      144.00M    0.00B    0.00B                      next_style_idx += 1
   151    79.27M      144.00M    0.00B    0.00B                  r+=1
   152
   153    79.27M      144.00M    0.00B    0.00B              if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
   154    79.27M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   155
   156    79.27M      144.00M    0.00B    0.00B      if multidevice:
   157                                                   net = setup_multi_device(net)
   158
   159                                               # Capture content targets
   160    79.27M      144.00M    0.00B    0.00B      for i in content_losses:
   161    79.27M      144.00M    0.00B    0.00B          i.mode = 'capture'
   162    79.27M      144.00M    0.00B    0.00B      print("Capturing content targets")
   163    79.27M      144.00M    0.00B    0.00B      print_torch(net, multidevice)
   164   156.85M      164.00M   77.57M   20.00M      net(content_image)
   165
   166                                               # Capture style targets
   167   156.85M      164.00M    0.00B    0.00B      for i in content_losses:
   168   156.85M      164.00M    0.00B    0.00B          i.mode = 'None'
   169
   170   160.64M      236.00M    3.79M   72.00M      for i, image in enumerate(style_images_caffe):
   171   156.85M      164.00M   -3.79M  -72.00M          print("Capturing style target " + str(i+1))
   172   156.85M      164.00M    0.00B    0.00B          for j in style_losses:
   173   156.85M      164.00M    0.00B    0.00B              j.mode = 'capture'
   174   156.85M      164.00M    0.00B    0.00B              j.blend_weight = style_blend_weights[i]
   175   160.64M      256.00M    3.79M   92.00M          net(style_images_caffe[i])
   176
   177                                               # Set all loss modules to loss mode
   178   160.64M      236.00M    0.00B  -20.00M      for i in content_losses:
   179   160.64M      236.00M    0.00B    0.00B          i.mode = 'loss'
   180   160.64M      236.00M    0.00B    0.00B      for i in style_losses:
   181   160.64M      236.00M    0.00B    0.00B          i.mode = 'loss'
   182
   183                                               # Freeze the network in order to prevent
   184                                               # unnecessary gradient calculations
   185   160.64M      236.00M    0.00B    0.00B      for param in net.parameters():
   186   160.64M      236.00M    0.00B    0.00B          param.requires_grad = False
   187
   188                                               # Initialize the image
   189   160.64M      236.00M    0.00B    0.00B      if params.seed >= 0:
   190   160.64M      236.00M    0.00B    0.00B          torch.manual_seed(params.seed)
   191   160.64M      236.00M    0.00B    0.00B          torch.cuda.manual_seed(params.seed)
   192   160.64M      236.00M    0.00B    0.00B          torch.cuda.manual_seed_all(params.seed)
   193   160.64M      236.00M    0.00B    0.00B          torch.backends.cudnn.deterministic=True
   194   160.64M      236.00M    0.00B    0.00B      if params.init == 'random':
   195   160.64M      236.00M    0.00B    0.00B          B, C, H, W = content_image.size()
   196   161.20M      236.00M  576.00K    0.00B          img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
   197                                               elif params.init == 'image':
   198                                                   if params.init_image != None:
   199                                                       img = init_image.clone()
   200                                                   else:
   201                                                       img = content_image.clone()
   202   161.20M      236.00M    0.00B    0.00B      img = nn.Parameter(img.type(dtype))
   203
   204   161.20M      236.00M    0.00B    0.00B      def maybe_print(t, loss):
   205                                                   if params.print_iter > 0 and t % params.print_iter == 0:
   206                                                       print("Iteration " + str(t) + " / "+ str(params.num_iterations))
   207                                                       for i, loss_module in enumerate(content_losses):
   208                                                           print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   209                                                       for i, loss_module in enumerate(style_losses):
   210                                                           print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   211                                                       print("  Total loss: " + str(loss.item()))
   212
   213   161.20M      236.00M    0.00B    0.00B      def maybe_save(t):
   214                                                   should_save = params.save_iter > 0 and t % params.save_iter == 0
   215                                                   should_save = should_save or t == params.num_iterations
   216                                                   if should_save:
   217                                                       output_filename, file_extension = os.path.splitext(params.output_image)
   218                                                       if t == params.num_iterations:
   219                                                           filename = output_filename + str(file_extension)
   220                                                       else:
   221                                                           filename = str(output_filename) + "_" + str(t) + str(file_extension)
   222                                                       disp = deprocess(img.clone())
   223
   224                                                       # Maybe perform postprocessing for color-independent style transfer
   225                                                       if params.original_colors == 1:
   226                                                           disp = original_colors(deprocess(content_image.clone()), disp)
   227
   228                                                       disp.save(str(filename))
   229
   230                                               # Function to evaluate loss and gradient. We run the net forward and
   231                                               # backward to get the gradient, and sum up losses from the loss modules.
   232                                               # optim.lbfgs internally handles iteration and calls this function many
   233                                               # times, so we manually count the number of iterations to handle printing
   234                                               # and saving intermediate results.
   235   161.20M      236.00M    0.00B    0.00B      num_calls = [0]
   236   161.20M      236.00M    0.00B    0.00B      def feval():
   237                                                   num_calls[0] += 1
   238                                                   optimizer.zero_grad()
   239                                                   net(img)
   240                                                   loss = 0
   241
   242                                                   for mod in content_losses:
   243                                                       loss += mod.loss.to(backward_device)
   244                                                   for mod in style_losses:
   245                                                       loss += mod.loss.to(backward_device)
   246                                                   if params.tv_weight > 0:
   247                                                       for mod in tv_losses:
   248                                                           loss += mod.loss.to(backward_device)
   249
   250                                                   loss.backward()
   251
   252                                                   maybe_save(num_calls[0])
   253                                                   maybe_print(num_calls[0], loss)
   254
   255                                                   return loss
   256
   257   161.20M      236.00M  161.20M  236.00M      optimizer, loopVal = setup_optimizer(img)
   258   201.90M      290.00M   40.70M   54.00M      while num_calls[0] <= loopVal:
   259   201.90M      376.00M    0.00B   86.00M           optimizer.step(feval)

nvidia-smi reported GPU memory usage: 711 MiB

Image size 512:

ubuntu@ip-Address:~/neural-style-pt$ python3 neural_style.py -num_iterations 500 -gpu 0,c -multidevice_strategy 0 -backend cudnn -seed 876 -image_size 512

File: neural_style.py
Function: main at line 56

Line # Max usage   Peak usage diff max diff peak  Line Contents
===============================================================
    56                                           @profile
    57                                           def main():
    58     0.00B        0.00B -496.02M -582.00M      dtype, multidevice, backward_device = setup_gpu()
    59
    60    76.39M      574.00M   76.39M  574.00M      cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, False)
    61
    62    78.64M      102.00M    2.25M -472.00M      content_image = preprocess(params.content_image, params.image_size).type(dtype)
    63    78.64M      102.00M    0.00B    0.00B      style_image_input = params.style_image.split(',')
    64    78.64M      102.00M    0.00B    0.00B      style_image_list, ext = [], [".jpg",".png"]
    65    78.64M      102.00M    0.00B    0.00B      for image in style_image_input:
    66    78.64M      102.00M    0.00B    0.00B          if os.path.isdir(image):
    67                                                       images = (image + "/" + file for file in os.listdir(image)
    68                                                       if os.path.splitext(file)[1].lower() in ext)
    69                                                       style_image_list.extend(images)
    70                                                   else:
    71    78.64M      102.00M    0.00B    0.00B              style_image_list.append(image)
    72    78.64M      102.00M    0.00B    0.00B      style_images_caffe = []
    73    80.94M      102.00M    2.30M    0.00B      for image in style_image_list:
    74    78.64M      102.00M   -2.30M    0.00B          style_size = int(params.image_size * params.style_scale)
    75    80.94M      102.00M    2.30M    0.00B          img_caffe = preprocess(image, style_size).type(dtype)
    76    80.94M      102.00M    0.00B    0.00B          style_images_caffe.append(img_caffe)
    77
    78    80.94M      102.00M    0.00B    0.00B      if params.init_image != None:
    79                                                   image_size = (content_image.size(2), content_image.size(3))
    80                                                   init_image = preprocess(params.init_image, image_size).type(dtype)
    81
    82                                               # Handle style blending weights for multiple style inputs
    83    80.94M      102.00M    0.00B    0.00B      style_blend_weights = []
    84    80.94M      102.00M    0.00B    0.00B      if params.style_blend_weights == None:
    85                                                   # Style blending not specified, so use equal weighting
    86    80.94M      102.00M    0.00B    0.00B          for i in style_image_list:
    87    80.94M      102.00M    0.00B    0.00B              style_blend_weights.append(1.0)
    88    80.94M      102.00M    0.00B    0.00B          for i, blend_weights in enumerate(style_blend_weights):
    89    80.94M      102.00M    0.00B    0.00B              style_blend_weights[i] = int(style_blend_weights[i])
    90                                               else:
    91                                                   style_blend_weights = params.style_blend_weights.split(',')
    92                                                   assert len(style_blend_weights) == len(style_image_list), \
    93                                                     "-style_blend_weights and -style_images must have the same number of elements!"
    94
    95                                               # Normalize the style blending weights so they sum to 1
    96    80.94M      102.00M    0.00B    0.00B      style_blend_sum = 0
    97    80.94M      102.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
    98    80.94M      102.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i])
    99    80.94M      102.00M    0.00B    0.00B          style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
   100    80.94M      102.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
   101    80.94M      102.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
   102
   103    80.94M      102.00M    0.00B    0.00B      content_layers = params.content_layers.split(',')
   104    80.94M      102.00M    0.00B    0.00B      style_layers = params.style_layers.split(',')
   105
   106                                               # Set up the network, inserting style and content loss modules
   107    82.64M      184.00M    1.70M   82.00M      cnn = copy.deepcopy(cnn)
   108    82.64M      144.00M    0.00B  -40.00M      content_losses, style_losses, tv_losses = [], [], []
   109    82.64M      144.00M    0.00B    0.00B      next_content_idx, next_style_idx = 1, 1
   110    82.64M      144.00M    0.00B    0.00B      net = nn.Sequential()
   111    82.64M      144.00M    0.00B    0.00B      c, r = 0, 0
   112    82.64M      144.00M    0.00B    0.00B      if params.tv_weight > 0:
   113    82.64M      144.00M    0.00B    0.00B          tv_mod = TVLoss(params.tv_weight).type(dtype)
   114    82.64M      144.00M    0.00B    0.00B          net.add_module(str(len(net)), tv_mod)
   115    82.64M      144.00M    0.00B    0.00B          tv_losses.append(tv_mod)
   116
   117    82.64M      144.00M    0.00B    0.00B      for i, layer in enumerate(list(cnn), 1):
   118    82.64M      144.00M    0.00B    0.00B          if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
   119    82.64M      144.00M    0.00B    0.00B              if isinstance(layer, nn.Conv2d):
   120    82.64M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   121
   122    82.64M      144.00M    0.00B    0.00B                  if layerList['C'][c] in content_layers:
   123                                                               print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
   124                                                               loss_module = ContentLoss(params.content_weight)
   125                                                               net.add_module(str(len(net)), loss_module)
   126                                                               content_losses.append(loss_module)
   127
   128    82.64M      144.00M    0.00B    0.00B                  if layerList['C'][c] in style_layers:
   129                                                               print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
   130                                                               loss_module = StyleLoss(params.style_weight)
   131                                                               net.add_module(str(len(net)), loss_module)
   132                                                               style_losses.append(loss_module)
   133    82.64M      144.00M    0.00B    0.00B                  c+=1
   134
   135    82.64M      144.00M    0.00B    0.00B              if isinstance(layer, nn.ReLU):
   136    82.64M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   137
   138    82.64M      144.00M    0.00B    0.00B                  if layerList['R'][r] in content_layers:
   139    82.64M      144.00M    0.00B    0.00B                      print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
   140    82.64M      144.00M    0.00B    0.00B                      loss_module = ContentLoss(params.content_weight)
   141    82.64M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   142    82.64M      144.00M    0.00B    0.00B                      content_losses.append(loss_module)
   143    82.64M      144.00M    0.00B    0.00B                      next_content_idx += 1
   144
   145    82.64M      144.00M    0.00B    0.00B                  if layerList['R'][r] in style_layers:
   146    82.64M      144.00M    0.00B    0.00B                      print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
   147    82.64M      144.00M    0.00B    0.00B                      loss_module = StyleLoss(params.style_weight)
   148    82.64M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   149    82.64M      144.00M    0.00B    0.00B                      style_losses.append(loss_module)
   150    82.64M      144.00M    0.00B    0.00B                      next_style_idx += 1
   151    82.64M      144.00M    0.00B    0.00B                  r+=1
   152
   153    82.64M      144.00M    0.00B    0.00B              if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
   154    82.64M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   155
   156    82.64M      144.00M    0.00B    0.00B      if multidevice:
   157    31.56M      144.00M  -51.08M    0.00B          net = setup_multi_device(net)
   158
   159                                               # Capture content targets
   160    31.56M       62.00M    0.00B  -82.00M      for i in content_losses:
   161    31.56M       62.00M    0.00B    0.00B          i.mode = 'capture'
   162    31.56M       62.00M    0.00B    0.00B      print("Capturing content targets")
   163    31.56M       62.00M    0.00B    0.00B      print_torch(net, multidevice)
   164    36.88M       62.00M    5.32M    0.00B      net(content_image)
   165
   166                                               # Capture style targets
   167    36.88M       62.00M    0.00B    0.00B      for i in content_losses:
   168    36.88M       62.00M    0.00B    0.00B          i.mode = 'None'
   169
   170    36.93M       62.00M   54.00K    0.00B      for i, image in enumerate(style_images_caffe):
   171    36.88M       62.00M  -54.00K    0.00B          print("Capturing style target " + str(i+1))
   172    36.88M       62.00M    0.00B    0.00B          for j in style_losses:
   173    36.88M       62.00M    0.00B    0.00B              j.mode = 'capture'
   174    36.88M       62.00M    0.00B    0.00B              j.blend_weight = style_blend_weights[i]
   175    36.93M       62.00M   54.00K    0.00B          net(style_images_caffe[i])
   176
   177                                               # Set all loss modules to loss mode
   178    36.93M       62.00M    0.00B    0.00B      for i in content_losses:
   179    36.93M       62.00M    0.00B    0.00B          i.mode = 'loss'
   180    36.93M       62.00M    0.00B    0.00B      for i in style_losses:
   181    36.93M       62.00M    0.00B    0.00B          i.mode = 'loss'
   182
   183                                               # Freeze the network in order to prevent
   184                                               # unnecessary gradient calculations
   185    36.93M       62.00M    0.00B    0.00B      for param in net.parameters():
   186                                                   param.requires_grad = False
   187
   188                                               # Initialize the image
   189    36.93M       62.00M    0.00B    0.00B      if params.seed >= 0:
   190    36.93M       62.00M    0.00B    0.00B          torch.manual_seed(params.seed)
   191    36.93M       62.00M    0.00B    0.00B          torch.cuda.manual_seed(params.seed)
   192    36.93M       62.00M    0.00B    0.00B          torch.cuda.manual_seed_all(params.seed)
   193    36.93M       62.00M    0.00B    0.00B          torch.backends.cudnn.deterministic=True
   194    36.93M       62.00M    0.00B    0.00B      if params.init == 'random':
   195    36.93M       62.00M    0.00B    0.00B          B, C, H, W = content_image.size()
   196    39.18M       62.00M    2.25M    0.00B          img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
   197                                               elif params.init == 'image':
   198                                                   if params.init_image != None:
   199                                                       img = init_image.clone()
   200                                                   else:
   201                                                       img = content_image.clone()
   202    39.18M       62.00M    0.00B    0.00B      img = nn.Parameter(img.type(dtype))
   203
   204    39.18M       62.00M    0.00B    0.00B      def maybe_print(t, loss):
   205                                                   if params.print_iter > 0 and t % params.print_iter == 0:
   206                                                       print("Iteration " + str(t) + " / "+ str(params.num_iterations))
   207                                                       for i, loss_module in enumerate(content_losses):
   208                                                           print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   209                                                       for i, loss_module in enumerate(style_losses):
   210                                                           print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   211                                                       print("  Total loss: " + str(loss.item()))
   212
   213    39.18M       62.00M    0.00B    0.00B      def maybe_save(t):
   214                                                   should_save = params.save_iter > 0 and t % params.save_iter == 0
   215                                                   should_save = should_save or t == params.num_iterations
   216                                                   if should_save:
   217                                                       output_filename, file_extension = os.path.splitext(params.output_image)
   218                                                       if t == params.num_iterations:
   219                                                           filename = output_filename + str(file_extension)
   220                                                       else:
   221                                                           filename = str(output_filename) + "_" + str(t) + str(file_extension)
   222                                                       disp = deprocess(img.clone())
   223
   224                                                       # Maybe perform postprocessing for color-independent style transfer
   225                                                       if params.original_colors == 1:
   226                                                           disp = original_colors(deprocess(content_image.clone()), disp)
   227
   228                                                       disp.save(str(filename))
   229
   230                                               # Function to evaluate loss and gradient. We run the net forward and
   231                                               # backward to get the gradient, and sum up losses from the loss modules.
   232                                               # optim.lbfgs internally handles iteration and calls this function many
   233                                               # times, so we manually count the number of iterations to handle printing
   234                                               # and saving intermediate results.
   235    39.18M       62.00M    0.00B    0.00B      num_calls = [0]
   236    39.18M       62.00M    0.00B    0.00B      def feval():
   237                                                   num_calls[0] += 1
   238                                                   optimizer.zero_grad()
   239                                                   net(img)
   240                                                   loss = 0
   241
   242                                                   for mod in content_losses:
   243                                                       loss += mod.loss.to(backward_device)
   244                                                   for mod in style_losses:
   245                                                       loss += mod.loss.to(backward_device)
   246                                                   if params.tv_weight > 0:
   247                                                       for mod in tv_losses:
   248                                                           loss += mod.loss.to(backward_device)
   249
   250                                                   loss.backward()
   251
   252                                                   maybe_save(num_calls[0])
   253                                                   maybe_print(num_calls[0], loss)
   254
   255                                                   return loss
   256
   257    39.18M       62.00M   39.18M   62.00M      optimizer, loopVal = setup_optimizer(img)
   258   496.02M      582.00M  456.84M  520.00M      while num_calls[0] <= loopVal:
   259   496.02M      582.00M    0.00B    0.00B           optimizer.step(feval)

nvidia-smi reported GPU memory usage: 915 MiB

Test 2 (single GPU: -gpu 0, no -multidevice_strategy, image size 512):

ubuntu@ip-Address:~/neural-style-pt$ python3 neural_style.py -num_iterations 500 -gpu 0 -backend cudnn -seed 876 -image_size 512

File: neural_style.py
Function: main at line 56

Line # Max usage   Peak usage diff max diff peak  Line Contents
===============================================================
    56                                           @profile
    57                                           def main():
    58     0.00B        0.00B -563.65M -934.00M      dtype, multidevice, backward_device = setup_gpu()
    59
    60    76.39M      574.00M   76.39M  574.00M      cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, False)
    61
    62    78.64M      102.00M    2.25M -472.00M      content_image = preprocess(params.content_image, params.image_size).type(dtype)
    63    78.64M      102.00M    0.00B    0.00B      style_image_input = params.style_image.split(',')
    64    78.64M      102.00M    0.00B    0.00B      style_image_list, ext = [], [".jpg",".png"]
    65    78.64M      102.00M    0.00B    0.00B      for image in style_image_input:
    66    78.64M      102.00M    0.00B    0.00B          if os.path.isdir(image):
    67                                                       images = (image + "/" + file for file in os.listdir(image)
    68                                                       if os.path.splitext(file)[1].lower() in ext)
    69                                                       style_image_list.extend(images)
    70                                                   else:
    71    78.64M      102.00M    0.00B    0.00B              style_image_list.append(image)
    72    78.64M      102.00M    0.00B    0.00B      style_images_caffe = []
    73    80.94M      102.00M    2.30M    0.00B      for image in style_image_list:
    74    78.64M      102.00M   -2.30M    0.00B          style_size = int(params.image_size * params.style_scale)
    75    80.94M      102.00M    2.30M    0.00B          img_caffe = preprocess(image, style_size).type(dtype)
    76    80.94M      102.00M    0.00B    0.00B          style_images_caffe.append(img_caffe)
    77
    78    80.94M      102.00M    0.00B    0.00B      if params.init_image != None:
    79                                                   image_size = (content_image.size(2), content_image.size(3))
    80                                                   init_image = preprocess(params.init_image, image_size).type(dtype)
    81
    82                                               # Handle style blending weights for multiple style inputs
    83    80.94M      102.00M    0.00B    0.00B      style_blend_weights = []
    84    80.94M      102.00M    0.00B    0.00B      if params.style_blend_weights == None:
    85                                                   # Style blending not specified, so use equal weighting
    86    80.94M      102.00M    0.00B    0.00B          for i in style_image_list:
    87    80.94M      102.00M    0.00B    0.00B              style_blend_weights.append(1.0)
    88    80.94M      102.00M    0.00B    0.00B          for i, blend_weights in enumerate(style_blend_weights):
    89    80.94M      102.00M    0.00B    0.00B              style_blend_weights[i] = int(style_blend_weights[i])
    90                                               else:
    91                                                   style_blend_weights = params.style_blend_weights.split(',')
    92                                                   assert len(style_blend_weights) == len(style_image_list), \
    93                                                     "-style_blend_weights and -style_images must have the same number of elements!"
    94
    95                                               # Normalize the style blending weights so they sum to 1
    96    80.94M      102.00M    0.00B    0.00B      style_blend_sum = 0
    97    80.94M      102.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
    98    80.94M      102.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i])
    99    80.94M      102.00M    0.00B    0.00B          style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
   100    80.94M      102.00M    0.00B    0.00B      for i, blend_weights in enumerate(style_blend_weights):
   101    80.94M      102.00M    0.00B    0.00B          style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
   102
   103    80.94M      102.00M    0.00B    0.00B      content_layers = params.content_layers.split(',')
   104    80.94M      102.00M    0.00B    0.00B      style_layers = params.style_layers.split(',')
   105
   106                                               # Set up the network, inserting style and content loss modules
   107    82.64M      184.00M    1.70M   82.00M      cnn = copy.deepcopy(cnn)
   108    82.64M      144.00M    0.00B  -40.00M      content_losses, style_losses, tv_losses = [], [], []
   109    82.64M      144.00M    0.00B    0.00B      next_content_idx, next_style_idx = 1, 1
   110    82.64M      144.00M    0.00B    0.00B      net = nn.Sequential()
   111    82.64M      144.00M    0.00B    0.00B      c, r = 0, 0
   112    82.64M      144.00M    0.00B    0.00B      if params.tv_weight > 0:
   113    82.64M      144.00M    0.00B    0.00B          tv_mod = TVLoss(params.tv_weight).type(dtype)
   114    82.64M      144.00M    0.00B    0.00B          net.add_module(str(len(net)), tv_mod)
   115    82.64M      144.00M    0.00B    0.00B          tv_losses.append(tv_mod)
   116
   117    82.64M      144.00M    0.00B    0.00B      for i, layer in enumerate(list(cnn), 1):
   118    82.64M      144.00M    0.00B    0.00B          if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
   119    82.64M      144.00M    0.00B    0.00B              if isinstance(layer, nn.Conv2d):
   120    82.64M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   121
   122    82.64M      144.00M    0.00B    0.00B                  if layerList['C'][c] in content_layers:
   123                                                               print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
   124                                                               loss_module = ContentLoss(params.content_weight)
   125                                                               net.add_module(str(len(net)), loss_module)
   126                                                               content_losses.append(loss_module)
   127
   128    82.64M      144.00M    0.00B    0.00B                  if layerList['C'][c] in style_layers:
   129                                                               print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
   130                                                               loss_module = StyleLoss(params.style_weight)
   131                                                               net.add_module(str(len(net)), loss_module)
   132                                                               style_losses.append(loss_module)
   133    82.64M      144.00M    0.00B    0.00B                  c+=1
   134
   135    82.64M      144.00M    0.00B    0.00B              if isinstance(layer, nn.ReLU):
   136    82.64M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   137
   138    82.64M      144.00M    0.00B    0.00B                  if layerList['R'][r] in content_layers:
   139    82.64M      144.00M    0.00B    0.00B                      print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
   140    82.64M      144.00M    0.00B    0.00B                      loss_module = ContentLoss(params.content_weight)
   141    82.64M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   142    82.64M      144.00M    0.00B    0.00B                      content_losses.append(loss_module)
   143    82.64M      144.00M    0.00B    0.00B                      next_content_idx += 1
   144
   145    82.64M      144.00M    0.00B    0.00B                  if layerList['R'][r] in style_layers:
   146    82.64M      144.00M    0.00B    0.00B                      print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
   147    82.64M      144.00M    0.00B    0.00B                      loss_module = StyleLoss(params.style_weight)
   148    82.64M      144.00M    0.00B    0.00B                      net.add_module(str(len(net)), loss_module)
   149    82.64M      144.00M    0.00B    0.00B                      style_losses.append(loss_module)
   150    82.64M      144.00M    0.00B    0.00B                      next_style_idx += 1
   151    82.64M      144.00M    0.00B    0.00B                  r+=1
   152
   153    82.64M      144.00M    0.00B    0.00B              if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
   154    82.64M      144.00M    0.00B    0.00B                  net.add_module(str(len(net)), layer)
   155
   156    82.64M      144.00M    0.00B    0.00B      if multidevice:
   157                                                   net = setup_multi_device(net)
   158
   159                                               # Capture content targets
   160    82.64M      144.00M    0.00B    0.00B      for i in content_losses:
   161    82.64M      144.00M    0.00B    0.00B          i.mode = 'capture'
   162    82.64M      144.00M    0.00B    0.00B      print("Capturing content targets")
   163    82.64M      144.00M    0.00B    0.00B      print_torch(net, multidevice)
   164   375.34M      402.00M  292.70M  258.00M      net(content_image)
   165
   166                                               # Capture style targets
   167   375.34M      402.00M    0.00B    0.00B      for i in content_losses:
   168   375.34M      402.00M    0.00B    0.00B          i.mode = 'None'
   169
   170   390.92M      488.00M   15.58M   86.00M      for i, image in enumerate(style_images_caffe):
   171   375.34M      402.00M  -15.58M  -86.00M          print("Capturing style target " + str(i+1))
   172   375.34M      402.00M    0.00B    0.00B          for j in style_losses:
   173   375.34M      402.00M    0.00B    0.00B              j.mode = 'capture'
   174   375.34M      402.00M    0.00B    0.00B              j.blend_weight = style_blend_weights[i]
   175   390.92M      704.00M   15.58M  302.00M          net(style_images_caffe[i])
   176
   177                                               # Set all loss modules to loss mode
   178   390.92M      488.00M    0.00B -216.00M      for i in content_losses:
   179   390.92M      488.00M    0.00B    0.00B          i.mode = 'loss'
   180   390.92M      488.00M    0.00B    0.00B      for i in style_losses:
   181   390.92M      488.00M    0.00B    0.00B          i.mode = 'loss'
   182
   183                                               # Freeze the network in order to prevent
   184                                               # unnecessary gradient calculations
   185   390.92M      488.00M    0.00B    0.00B      for param in net.parameters():
   186   390.92M      488.00M    0.00B    0.00B          param.requires_grad = False
   187
   188                                               # Initialize the image
   189   390.92M      488.00M    0.00B    0.00B      if params.seed >= 0:
   190   390.92M      488.00M    0.00B    0.00B          torch.manual_seed(params.seed)
   191   390.92M      488.00M    0.00B    0.00B          torch.cuda.manual_seed(params.seed)
   192   390.92M      488.00M    0.00B    0.00B          torch.cuda.manual_seed_all(params.seed)
   193   390.92M      488.00M    0.00B    0.00B          torch.backends.cudnn.deterministic=True
   194   390.92M      488.00M    0.00B    0.00B      if params.init == 'random':
   195   390.92M      488.00M    0.00B    0.00B          B, C, H, W = content_image.size()
   196   393.17M      488.00M    2.25M    0.00B          img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
   197                                               elif params.init == 'image':
   198                                                   if params.init_image != None:
   199                                                       img = init_image.clone()
   200                                                   else:
   201                                                       img = content_image.clone()
   202   393.17M      488.00M    0.00B    0.00B      img = nn.Parameter(img.type(dtype))
   203
   204   393.17M      488.00M    0.00B    0.00B      def maybe_print(t, loss):
   205                                                   if params.print_iter > 0 and t % params.print_iter == 0:
   206                                                       print("Iteration " + str(t) + " / "+ str(params.num_iterations))
   207                                                       for i, loss_module in enumerate(content_losses):
   208                                                           print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   209                                                       for i, loss_module in enumerate(style_losses):
   210                                                           print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
   211                                                       print("  Total loss: " + str(loss.item()))
   212
   213   393.17M      488.00M    0.00B    0.00B      def maybe_save(t):
   214                                                   should_save = params.save_iter > 0 and t % params.save_iter == 0
   215                                                   should_save = should_save or t == params.num_iterations
   216                                                   if should_save:
   217                                                       output_filename, file_extension = os.path.splitext(params.output_image)
   218                                                       if t == params.num_iterations:
   219                                                           filename = output_filename + str(file_extension)
   220                                                       else:
   221                                                           filename = str(output_filename) + "_" + str(t) + str(file_extension)
   222                                                       disp = deprocess(img.clone())
   223
   224                                                       # Maybe perform postprocessing for color-independent style transfer
   225                                                       if params.original_colors == 1:
   226                                                           disp = original_colors(deprocess(content_image.clone()), disp)
   227
   228                                                       disp.save(str(filename))
   229
   230                                               # Function to evaluate loss and gradient. We run the net forward and
   231                                               # backward to get the gradient, and sum up losses from the loss modules.
   232                                               # optim.lbfgs internally handles iteration and calls this function many
   233                                               # times, so we manually count the number of iterations to handle printing
   234                                               # and saving intermediate results.
   235   393.17M      488.00M    0.00B    0.00B      num_calls = [0]
   236   393.17M      488.00M    0.00B    0.00B      def feval():
   237                                                   num_calls[0] += 1
   238                                                   optimizer.zero_grad()
   239                                                   net(img)
   240                                                   loss = 0
   241
   242                                                   for mod in content_losses:
   243                                                       loss += mod.loss.to(backward_device)
   244                                                   for mod in style_losses:
   245                                                       loss += mod.loss.to(backward_device)
   246                                                   if params.tv_weight > 0:
   247                                                       for mod in tv_losses:
   248                                                           loss += mod.loss.to(backward_device)
   249
   250                                                   loss.backward()
   251
   252                                                   maybe_save(num_calls[0])
   253                                                   maybe_print(num_calls[0], loss)
   254
   255                                                   return loss
   256
   257   393.17M      488.00M  393.17M  488.00M      optimizer, loopVal = setup_optimizer(img)
   258   563.65M      636.00M  170.48M  148.00M      while num_calls[0] <= loopVal:
   259   563.65M      934.00M    0.00B  298.00M           optimizer.step(feval)

nvidia-smi:

GPU memory usage reported by nvidia-smi: 1271 MiB
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment