@zhpmatrix
Created March 8, 2019 13:32
"""
https://discuss.pytorch.org/t/multi-layer-rnn-with-dataparallel/4450/2
https://pytorch.org/docs/stable/nn.html
"""
import torch
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
class Net(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super(Net, self).__init__()
self.gru = torch.nn.GRU(input_size,hidden_size, num_layers=2, batch_first=False)
for p in self.gru.parameters():
torch.nn.init.normal_(p)
def forward(self, input_, h0):
output, ht = self.gru(input_,h0)
return output, ht
if __name__ == '__main__':
model = torch.nn.DataParallel(Net(10,200), device_ids = [0,1], dim=1).cuda()
input_ = torch.randn(5,3,10)
h0 = torch.randn(2,3,200)
output,hn = model(input_,h0)
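    # Sanity check (a sketch, assuming both GPUs set above are available):
    # DataParallel gathers the per-GPU results back along dim=1, so the
    # full batch should reappear in the gathered shapes.
    print(output.shape)  # torch.Size([5, 3, 200]) -> (seq_len, batch, hidden_size)
    print(hn.shape)      # torch.Size([2, 3, 200]) -> (num_layers, batch, hidden_size)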