@georgestanley
Created January 4, 2023 17:43
PyTorch code snippet to train in parallel on multiple GPUs
import torch
import torch.nn as nn


def initialize_model(hidden_dim, hidden_layers, lr, device):
    # Wrap the model in nn.DataParallel (see below); PyTorch will then
    # parallelize training across all visible GPUs if more than one is found.
    input_dim = alph_len * 3  # alph_len and LSTMModel are defined elsewhere in the gist
    layer_dim = hidden_layers
    output_dim = 2
    model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim, device)
    model = nn.DataParallel(model)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    return model, criterion, optimizer
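
A minimal usage sketch of the function above. It is only an illustration, assuming hyperparameter values chosen here and a hypothetical train_loader DataLoader whose batches match what LSTMModel expects; nn.DataParallel splits each input batch along the first dimension across the available GPUs and gathers the outputs back on the primary device.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, criterion, optimizer = initialize_model(hidden_dim=128, hidden_layers=2,
                                                   lr=1e-3, device=device)

    for inputs, labels in train_loader:  # train_loader is a hypothetical DataLoader
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)          # batch is split across GPUs by DataParallel
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()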