-
-
Save heljakka/ff8e0cd97da8ccf06c7973c06d5fe82e to your computer and use it in GitHub Desktop.
Fix for PyTorch issue 12671, in which replicate() also created a replica of the Network object for device 0. As a result, hooks ran on that replica instead of the original object, so the U matrix of spectral norm was not updated correctly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fix for PyTorch issue 12671, in which replicate() also created a replica of the Network object for device 0. As a result, hooks ran on that replica instead of the original object, so the U matrix of spectral norm was not updated correctly.
# Replaces the file /lib/python3.6/site-packages/torch/nn/parallel/replicate.py
# Based on the PyTorch 0.4.1 version of that file; diff against your installed copy first to check whether it needs subtle modifications.
# An official fix has since been released and appears more comprehensive than this one (it also fixes an issue with .detach).
import torch.cuda.comm as comm | |
def replicate(network, devices, detach=False):
    """Return one module tree per entry of ``devices``.

    Index 0 of the result is the original ``network`` object tree itself;
    only indices ``1 .. len(devices) - 1`` receive shallow clones whose
    submodules, parameters and buffers are rewired to the broadcast copies
    for that index. Keeping the original at index 0 is the point of this
    patched version: work done on the first entry acts on the real modules
    rather than on a throwaway copy.

    Args:
        network: root module to replicate.
        devices: iterable of target devices; it is materialised as a tuple.
        detach: when true, cloned modules receive ``.detach()``-ed parameter
            copies. It has no effect on index 0, because the original
            modules' parameters are never reassigned here.

    Returns:
        A list with ``len(devices)`` root modules, the first being the
        first module yielded by ``network.modules()``.

    NOTE(review): parameters and buffers are still broadcast for index 0
    even though those copies are never attached to a module. This
    presumably assumes ``network`` already lives on ``devices[0]``;
    confirm against the caller.
    """
    from ._functions import Broadcast

    devices = tuple(devices)
    num_replicas = len(devices)
    # Slot 0 is served by the original modules, so only these slots are
    # cloned and rewired below.
    clone_slots = range(1, num_replicas)

    # Broadcast all parameters in one call, then regroup the flat result
    # into consecutive chunks of len(params), indexed by slot.
    params = list(network.parameters())
    param_indices = {p: pos for pos, p in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if params:
        chunk = len(params)
        param_copies = [param_copies[start:start + chunk]
                        for start in range(0, len(param_copies), chunk)]

    buffers = list(network._all_buffers())
    buffer_indices = {b: pos for pos, b in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for _ in devices]
    module_indices = {}

    # Pass 1: create a shallow clone of every module for each clone slot.
    # The three registries are copied so that rewiring a clone never
    # touches the dictionaries owned by the original module.
    for pos, original in enumerate(modules):
        module_indices[original] = pos
        for slot in clone_slots:
            clone = original.__new__(type(original))
            clone.__dict__ = original.__dict__.copy()
            clone._parameters = clone._parameters.copy()
            clone._buffers = clone._buffers.copy()
            clone._modules = clone._modules.copy()
            module_copies[slot].append(clone)
        module_copies[0].append(original)

    # Pass 2: point every clone at the children / parameters / buffers
    # that belong to its own slot. None entries are preserved as None.
    for pos, original in enumerate(modules):
        for key, child in original._modules.items():
            child_pos = None if child is None else module_indices[child]
            for slot in clone_slots:
                slot_modules = module_copies[slot]
                if child_pos is None:
                    slot_modules[pos]._modules[key] = None
                else:
                    slot_modules[pos]._modules[key] = slot_modules[child_pos]

        for key, param in original._parameters.items():
            param_pos = None if param is None else param_indices[param]
            for slot in clone_slots:
                clone = module_copies[slot][pos]
                if param_pos is None:
                    clone._parameters[key] = None
                    continue
                copied = param_copies[slot][param_pos]
                clone._parameters[key] = copied.detach() if detach else copied

        for key, buf in original._buffers.items():
            buf_pos = None if buf is None else buffer_indices[buf]
            for slot in clone_slots:
                clone = module_copies[slot][pos]
                if buf_pos is None:
                    clone._buffers[key] = None
                else:
                    clone._buffers[key] = buffer_copies[slot][buf_pos]

    # The root module sits at position 0 of every per-slot list.
    return [slot_modules[0] for slot_modules in module_copies]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment