Skip to content

Instantly share code, notes, and snippets.

@heljakka
Last active December 17, 2018 14:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save heljakka/ff8e0cd97da8ccf06c7973c06d5fe82e to your computer and use it in GitHub Desktop.
Save heljakka/ff8e0cd97da8ccf06c7973c06d5fe82e to your computer and use it in GitHub Desktop.
Fix to PyTorch issue 12671 which replicated the Network object for device 0. This made the hooks run on a replica instead of the original version, which caused the U matrix of spectral norm not to update correctly.
# Fix for PyTorch issue 12671: replicate() created a replica of the network even for device 0, so hooks ran on that replica instead of on the original module, and the U matrix of spectral norm was not updated correctly.
# Replaces file /lib/python3.6/site-packages/torch/nn/parallel/replicate.py
# Based on the PyTorch 0.4.1 version of this file; diff against your installed copy first to check whether it needs some subtle modifications.
# An official fix has since been released and appears more comprehensive than this one (it also fixes an issue with .detach).
import torch.cuda.comm as comm
def replicate(network, devices, detach=False):
    """Build one module tree per entry of ``devices``.

    The original modules of ``network`` are reused, untouched, as the tree
    for ``devices[0]``; shallow copies are created only for the remaining
    devices. Because index 0 is the original, anything attached to the
    original modules keeps operating on them rather than on a throwaway
    copy.

    Args:
        network: Module to replicate. It must provide ``parameters()``,
            ``_all_buffers()`` and ``modules()``, and every module it
            yields must carry ``_parameters``, ``_buffers`` and
            ``_modules`` mappings.
        devices: Iterable of target devices; it is materialised as a tuple.
        detach: When true, the parameters installed on the copies are
            ``.detach()``-ed. This is never applied to index 0, because
            that entry is the original module tree.

    Returns:
        A list with one root module per device. Element 0 is the first
        module yielded by ``network.modules()``; the other elements are the
        corresponding copies.
    """
    from ._functions import Broadcast

    devices = tuple(devices)
    num_replicas = len(devices)
    # Slot 0 is served by the original modules, so only these slots get copies.
    copy_slots = range(1, num_replicas)

    # Broadcast every parameter to every device, then regroup the flat result
    # into one chunk per device.
    params = list(network.parameters())
    param_indices = {p: pos for pos, p in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if params:
        per_device = len(params)
        param_copies = [
            param_copies[start:start + per_device]
            for start in range(0, len(param_copies), per_device)
        ]

    buffers = list(network._all_buffers())
    buffer_indices = {b: pos for pos, b in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for _ in devices]
    module_indices = {}

    # Pass 1: create the per-device module objects. Every row of
    # module_copies ends up index-aligned with `modules`.
    for pos, original in enumerate(modules):
        module_indices[original] = pos
        for slot in copy_slots:
            clone = original.__new__(type(original))
            clone.__dict__ = original.__dict__.copy()
            # Give the clone its own mappings so that rewiring it below
            # cannot leak into the original module.
            clone._parameters = clone._parameters.copy()
            clone._buffers = clone._buffers.copy()
            clone._modules = clone._modules.copy()
            module_copies[slot].append(clone)
        module_copies[0].append(original)

    # Pass 2: point each clone's children, parameters and buffers at the
    # objects that belong to its own device.
    for pos, original in enumerate(modules):
        for key, child in original._modules.items():
            child_pos = None if child is None else module_indices[child]
            for slot in copy_slots:
                clone = module_copies[slot][pos]
                if child_pos is None:
                    clone._modules[key] = None
                else:
                    clone._modules[key] = module_copies[slot][child_pos]

        for key, param in original._parameters.items():
            param_pos = None if param is None else param_indices[param]
            for slot in copy_slots:
                clone = module_copies[slot][pos]
                if param_pos is None:
                    clone._parameters[key] = None
                elif detach:
                    clone._parameters[key] = param_copies[slot][param_pos].detach()
                else:
                    clone._parameters[key] = param_copies[slot][param_pos]

        for key, buf in original._buffers.items():
            buffer_pos = None if buf is None else buffer_indices[buf]
            for slot in copy_slots:
                clone = module_copies[slot][pos]
                if buffer_pos is None:
                    clone._buffers[key] = None
                else:
                    clone._buffers[key] = buffer_copies[slot][buffer_pos]

    # The root of each tree is the first entry of its row.
    return [row[0] for row in module_copies]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment