Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Last active October 15, 2019 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ProGamerGov/e495448fdc665a8570cf7b57257470aa to your computer and use it in GitHub Desktop.
# Inspired by: https://github.com/torch/nn/blob/master/GPU.lua
# And: https://github.com/jcjohnson/neural-style/blob/master/neural_style.lua#L360
# As seen in: https://github.com/ProGamerGov/neural-style-pt
import torch
import torch.nn as nn
class ModelParallel(nn.Module):
    r"""Splits a sequential network across multiple devices.

    Args:
        net (nn.Sequential): a sequential model to be split across devices.
        device_ids (str): comma-separated zero-indexed GPU ids, with 'c'
            selecting the CPU (e.g. "0,1,2" or "c,0").
        device_splits (str): comma-separated layer indices at which to start
            a new chunk (e.g. "2,5"); this yields one more chunk than splits,
            which must match the number of devices in ``device_ids``.

    Example::

        >>> net = ModelParallel(model, device_ids="0,1,2", device_splits="2,5")
        >>> net = ModelParallel(model, device_ids="c,0", device_splits="5")  # 'c' selects the CPU
    """

    def __init__(self, net, device_ids, device_splits):
        super(ModelParallel, self).__init__()
        self.device_list = self.name_devices(device_ids.split(','))
        # Bug fix: chunks_to_devices requires the device list as its second
        # argument; the original call omitted it and raised a TypeError.
        self.chunks = self.chunks_to_devices(
            self.split_net(net, device_splits.split(',')), self.device_list)

    def name_devices(self, input_list):
        r"""Convert zero-indexed GPU ids ('c' for CPU) to PyTorch device names.

        Arguments:
            input_list (list of str): zero-indexed GPU ids, and 'c' for CPU.
        Returns:
            list of str: PyTorch device strings, e.g. ["cuda:0", "cpu"].
        """
        return ["cpu" if str(device).lower() == 'c' else "cuda:" + str(device)
                for device in input_list]

    def split_net(self, net, device_splits):
        r"""Split a sequential net into chunks at the given layer indices.

        Arguments:
            net (nn.Sequential): the model to split.
            device_splits (list of str): layer indices at which a new chunk
                begins after that layer; consumed (mutated) during the split.
        Returns:
            list of nn.Sequential: len(device_splits) + 1 chunks for valid
            in-range splits.
        """
        chunks, cur_chunk = [], nn.Sequential()
        for i, layer in enumerate(net):
            cur_chunk.add_module(str(i), layer)
            # The original guard `device_splits != ''` compared a list with a
            # string (always True) and `del device_splits[0]` removed the
            # first split even when a later one matched; remove the matched
            # index instead so out-of-order split lists behave correctly.
            if str(i) in device_splits:
                device_splits.remove(str(i))
                chunks.append(cur_chunk)
                cur_chunk = nn.Sequential()
        chunks.append(cur_chunk)
        return chunks

    def chunks_to_devices(self, chunks, device_list):
        r"""Move each chunk onto its corresponding device.

        Arguments:
            chunks (list of nn.Sequential): model chunks, one per device.
            device_list (list of str): PyTorch device names, parallel to chunks.
        Returns:
            list of nn.Sequential: the same chunks, now on their devices.
        """
        for i, chunk in enumerate(chunks):
            # Use the parameter (the original read self.device_list and
            # ignored device_list entirely).
            chunk.to(device_list[i])
        return chunks

    def c(self, input, i):
        r"""Convert a float tensor's backend to match self.device_list[i].

        Arguments:
            input (Tensor): a float CPU or CUDA tensor.
            i (int): index into self.device_list.
        Returns:
            Tensor: input as the float tensor type of the target backend.
        """
        if input.type() == 'torch.FloatTensor' and 'cuda' in self.device_list[i]:
            input = input.type('torch.cuda.FloatTensor')
        elif input.type() == 'torch.cuda.FloatTensor' and 'cpu' in self.device_list[i]:
            input = input.type('torch.FloatTensor')
        return input

    def forward(self, input):
        # Run each chunk on its assigned device, converting the activations
        # to the next chunk's backend and moving them across between chunks.
        for i, chunk in enumerate(self.chunks):
            if i < len(self.chunks) - 1:
                input = self.c(chunk(self.c(input, i).to(self.device_list[i])),
                               i + 1).to(self.device_list[i + 1])
            else:
                input = chunk(input)
        return input
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment