Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Last active October 1, 2018 15:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrdrozdov/3bcb412ff60151f0cb6caf95d0fcccaa to your computer and use it in GitHub Desktop.
Save mrdrozdov/3bcb412ff60151f0cb6caf95d0fcccaa to your computer and use it in GitHub Desktop.
Check pinned memory / multi-GPU behavior with embeddings
import argparse
import json
import torch
import torch.nn as nn
from tqdm import tqdm
class Model(nn.Module):
    """Project inputs from ``embedding_dim`` to ``hidden_dim`` with one linear layer.

    Args:
        embedding_dim: Size of the last dimension of the input.
        hidden_dim: Size of the last dimension of the output.
        pin: When true, ``forward`` calls ``pin_memory()`` on its input
            before projecting it.
    """

    def __init__(self, embedding_dim, hidden_dim, pin=False):
        super(Model, self).__init__()
        self.W = nn.Linear(embedding_dim, hidden_dim)
        self.pin = pin

    def forward(self, x):
        # NOTE(review): when called through data_parallel the input has
        # presumably already been scattered to a GPU, and pin_memory() only
        # accepts CPU tensors -- confirm pin=True is usable in that setting.
        inputs = x.pin_memory() if self.pin else x
        return self.W(inputs)
class ModelWithEmbedding(nn.Module):
    """Look token ids up in a frozen embedding table, then project linearly.

    Args:
        embedding_dim: Width of each embedding row (input size of ``W``).
        hidden_dim: Output size of ``W``.
        embeddings: 2-D weight tensor used to build a frozen ``nn.Embedding``.
        pin: When true, ``forward`` pins the index tensor before the lookup.
    """

    def __init__(self, embedding_dim, hidden_dim, embeddings, pin=False):
        super(ModelWithEmbedding, self).__init__()
        self.W = nn.Linear(embedding_dim, hidden_dim)
        # freeze=True leaves the table's weight with requires_grad=False.
        self.embed = nn.Embedding.from_pretrained(embeddings, freeze=True)
        self.pin = pin

    def forward(self, x):
        indices = x.pin_memory() if self.pin else x
        return self.W(self.embed(indices))
def noop(x):
    """Identity function: hand back ``x`` unchanged.

    Used as the trainer-side "embedding" when the lookup happens inside the
    model instead.
    """
    return x
class Trainer(object):
    """Benchmark harness: embed token ids, then run a model via data_parallel.

    Args:
        hidden_dim: Output size of the model's linear layer.
        vocab_size: Number of rows in the frozen embedding table.
        embedding_dim: Width of each embedding row.
        embedding_gpu: Move the trainer-side embedding (if it is a module) to
            the current CUDA device, and copy token ids there before lookup.
        embedding_in_model: Put the embedding table inside the model instead
            of doing the lookup in the trainer.
        pin_early: Pin the token-id tensor before it is embedded.
        pin_first: Pin the embedding output before ``data_parallel`` (only
            when that output is still on the CPU).
        pin_second: Passed to the model as its ``pin`` flag.
    """

    def __init__(self, hidden_dim, vocab_size, embedding_dim, embedding_gpu=False,
                 embedding_in_model=False, pin_early=False, pin_first=False,
                 pin_second=False):
        # The legacy size constructor returns uninitialized memory; only the
        # table's shape matters for this timing benchmark.
        embeddings = torch.FloatTensor(vocab_size, embedding_dim)
        if embedding_in_model:
            # The model performs the lookup itself, so the trainer passes ids through.
            embed = noop
            model = ModelWithEmbedding(embedding_dim, hidden_dim, embeddings, pin=pin_second)
        else:
            embed = nn.Embedding.from_pretrained(embeddings, freeze=True)
            model = Model(embedding_dim, hidden_dim, pin=pin_second)
        self.pin_early = pin_early
        self.pin = pin_first
        self.embed = embed
        self.model = model
        self.embedding_gpu = embedding_gpu
        self.model.cuda()
        # `noop` is a plain function with no `.cuda()`; only modules can be moved.
        if embedding_gpu and isinstance(self.embed, nn.Module):
            self.embed.cuda()

    def step(self, x):
        """Run one forward pass for a batch of token ids.

        Args:
            x: Token-id tensor; the ``__main__`` driver builds it as a CPU
               LongTensor of shape ``(batch_size * length, 1)``.

        Returns:
            The result of ``data_parallel(self.model, (emb,))``.
        """
        if self.pin_early:
            x = x.pin_memory()
        if self.embedding_gpu:
            # `async=` was renamed `non_blocking=` in PyTorch 0.4, and `async`
            # is a reserved word from Python 3.7 on, so the old spelling no
            # longer even parses.
            x = x.cuda(non_blocking=True)
        emb = self.embed(x)
        # Only CPU tensors can be pinned; with embedding_gpu the lookup result
        # already lives on the GPU, so pinning it would raise.
        if self.pin and not emb.is_cuda:
            emb = emb.pin_memory()
        out = torch.nn.parallel.data_parallel(self.model, (emb, ))
        return out
if __name__ == '__main__':
    # Benchmark driver: parse sizes and placement flags, build a Trainer, and
    # run `nbatches` forward steps under a tqdm progress bar.
    parser = argparse.ArgumentParser()
    # Dims
    parser.add_argument('-v', '--vocab', default=1000000, type=int,
                        help='number of rows in the embedding table')
    parser.add_argument('-b', '--batch_size', default=256, type=int,
                        help='sequences per batch')
    parser.add_argument('-l', '--length', default=20, type=int,
                        help='tokens per sequence')
    parser.add_argument('-e', '--embedding_dim', default=300, type=int,
                        help='width of each embedding row')
    parser.add_argument('-d', '--hidden_dim', default=200, type=int,
                        help='output size of the linear layer')
    parser.add_argument('-n', '--nbatches', default=10000, type=int,
                        help='number of forward steps to run')
    # Config
    parser.add_argument('--embedding_gpu', action='store_true',
                        help='keep the trainer-side embedding on the GPU')
    parser.add_argument('--embedding_in_model', action='store_true',
                        help='put the embedding table inside the parallelized model')
    parser.add_argument('--pin_early', action='store_true',
                        help='pin token ids before the embedding lookup')
    parser.add_argument('--pin_first', action='store_true',
                        help='pin embedding output before data_parallel')
    parser.add_argument('--pin_second', action='store_true',
                        help='pin the input inside the model forward')
    options = parser.parse_args()
    print(json.dumps(vars(options), sort_keys=True, indent=4))
    torch.manual_seed(11)
    trainer = Trainer(options.hidden_dim,
                      options.vocab,
                      options.embedding_dim,
                      embedding_gpu=options.embedding_gpu,
                      embedding_in_model=options.embedding_in_model,
                      pin_early=options.pin_early,
                      pin_first=options.pin_first,
                      pin_second=options.pin_second)
    for _ in tqdm(range(options.nbatches)):
        # random_(from, to) samples from [from, to - 1], so `to` must be the
        # vocab size itself for every row of the table to be reachable.
        x = torch.LongTensor(options.batch_size * options.length, 1).random_(0, options.vocab)
        trainer.step(x)
# Embeddings on Each GPU / Pin Memory before embedding
python demo_embeddings.py --pin_early --embedding_in_model
time: > 30m
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.130 Driver Version: 384.130 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 GeForce GTX 1080 Off | 00000000:02:00.0 Off | N/A |
| 0% 48C P2 55W / 198W | 2801MiB / 8114MiB | 99% Default |
+-------------------------------+----------------------+----------------------+
| 1 GeForce GTX 1080 Off | 00000000:03:00.0 Off | N/A |
| 37% 57C P2 52W / 180W | 1661MiB / 8114MiB | 99% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 4658 C python 2791MiB |
| 1 4658 C python 1651MiB |
+-----------------------------------------------------------------------------+
# Embeddings on Single GPU / Pin Memory before embedding
python demo_embeddings.py --pin_early --embedding_gpu
time: 0m41s
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.130 Driver Version: 384.130 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 GeForce GTX 1080 Off | 00000000:02:00.0 Off | N/A |
| 0% 39C P2 46W / 198W | 1671MiB / 8114MiB | 13% Default |
+-------------------------------+----------------------+----------------------+
| 1 GeForce GTX 1080 Off | 00000000:03:00.0 Off | N/A |
| 29% 47C P2 41W / 180W | 513MiB / 8114MiB | 10% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 17383 C python 1661MiB |
| 1 17383 C python 503MiB |
+-----------------------------------------------------------------------------+
# Embeddings on CPU / Pin Memory before embedding
python demo_embeddings.py --pin_early
time: 0m24s
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.130 Driver Version: 384.130 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 GeForce GTX 1080 Off | 00000000:02:00.0 Off | N/A |
| 0% 43C P2 55W / 198W | 523MiB / 8114MiB | 63% Default |
+-------------------------------+----------------------+----------------------+
| 1 GeForce GTX 1080 Off | 00000000:03:00.0 Off | N/A |
| 33% 51C P2 43W / 180W | 517MiB / 8114MiB | 39% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 15879 C python 513MiB |
| 1 15879 C python 507MiB |
+-----------------------------------------------------------------------------+
# Embeddings on CPU / No explicit Pinned Memory
python demo_embeddings.py
time: 0m24s
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.130 Driver Version: 384.130 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 GeForce GTX 1080 Off | 00000000:02:00.0 Off | N/A |
| 0% 41C P2 57W / 198W | 523MiB / 8114MiB | 69% Default |
+-------------------------------+----------------------+----------------------+
| 1 GeForce GTX 1080 Off | 00000000:03:00.0 Off | N/A |
| 30% 48C P2 55W / 180W | 517MiB / 8114MiB | 40% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 5763 C python 513MiB |
| 1 5763 C python 507MiB |
+-----------------------------------------------------------------------------+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment