import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim

# Select the backend: prefer CUDA if it is available, otherwise assume a TPU host.
platform = "gpu" if torch.cuda.is_available() else "tpu"

if platform == "tpu":
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.debug.metrics as met

    spawn_fn = xmp.spawn
    nprocs = 8

    def device_fn(rank):
        # One XLA device per process, training in bfloat16.
        return xm.xla_device(), torch.bfloat16

    def ddp_fn(model, rank):
        # No wrapper needed; xm.optimizer_step handles gradient reduction.
        return model

    def itr_fn(dataloader, device):
        return pl.ParallelLoader(dataloader, [device]).per_device_loader(device)

    def prep_fn(batch, device, dtype):
        # ParallelLoader already moves batches to the XLA device.
        return batch

    def step_fn(optimizer):
        xm.optimizer_step(optimizer)

    def metrics_fn():
        print(met.metrics_report())

elif platform == "gpu":
    import torch.multiprocessing as mp

    spawn_fn = mp.spawn
    nprocs = 4

    def device_fn(rank):
        # One CUDA device per process, training in float16.
        return "cuda:" + str(rank), torch.float16

    def ddp_fn(model, rank):
        torch.distributed.init_process_group(
            "nccl",
            init_method="tcp://localhost:10234",
            world_size=nprocs,
            rank=rank,
        )
        return torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[rank],
        )

    def itr_fn(dataloader, device):
        return dataloader

    def prep_fn(batch, device, dtype):
        return batch.to(device=device)

    def step_fn(optimizer):
        optimizer.step()

    def metrics_fn():
        pass

else:
    raise NotImplementedError
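
# Both branches expose the same hooks (spawn_fn, nprocs, device_fn, ddp_fn, itr_fn,
# prep_fn, step_fn, metrics_fn), so the model and training loop below stay
# backend-agnostic.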

# Transformer-sized stand-in model: the self-attention computation is replaced by
# linear layers of matching shape (see the per-layer comments below).
class Net(nn.Module):

    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 3 * embed_dim),   # q, k, v input projection
                nn.Linear(3 * embed_dim, embed_dim),   # skip self-attention
                nn.Linear(embed_dim, embed_dim),       # output projection
                nn.Dropout(),
            )
            for i in range(num_layers)
        ])
        self.layers_b = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 4 * embed_dim),   # FFN
                nn.ReLU(),
                nn.Linear(4 * embed_dim, embed_dim),   # FFN
                nn.Dropout(0.1),
            )
            for i in range(num_layers)
        ])
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)   # residual around the "attention" block
            x = x + layer_b(x)   # residual around the FFN block
        x = self.out_proj(x)
        return x

def main(rank):
    bsz = 128 // nprocs   # global batch size of 128, split across processes
    seqlen = 512
    warmup_steps = 50
    measurement_steps = 50

    print("initializing dataloader")
    item = torch.arange(1, seqlen + 1, dtype=torch.long)
    dataloader = torch.utils.data.DataLoader(
        [item for _ in range(bsz * 1000)],
        batch_size=bsz,
        num_workers=4,
    )

    device, dtype = device_fn(rank)

    print("initializing model/opt/loss")
    model = Net().to(device=device, dtype=dtype)
    model = ddp_fn(model, rank)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    print("num model params: {}".format(sum(p.numel() for p in model.parameters())))

    print("initializing parallel loader")
    itr = itr_fn(dataloader, device)

    print("beginning warmup")
    for i, batch in enumerate(itr):
        if i == warmup_steps:
            print("end warmup, begin measurement")
            start_time = time.time()

        optimizer.zero_grad()   # clear gradients accumulated by the previous step
        batch = prep_fn(batch, device, dtype)
        x = model(batch)
        loss = loss_fn(
            x.view(-1, x.size(-1)),
            target=batch.view(-1),
        )
        loss.backward()
        step_fn(optimizer)

        if i == warmup_steps + measurement_steps:
            measured_time = time.time() - start_time
            print(
                "end measurement, time for rank {}: {}"
                .format(rank, measured_time)
            )
            break

    metrics_fn()

if __name__ == "__main__":
    spawn_fn(main, args=(), nprocs=nprocs)
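
# Usage sketch (hardware details below are assumptions, not part of the script):
# running this file directly on a host where CUDA is available spawns 4 GPU worker
# processes via torch.multiprocessing and trains in float16 with NCCL
# DistributedDataParallel; on a host without CUDA but with torch_xla installed
# (e.g. an 8-core Cloud TPU), it spawns 8 XLA processes via xmp.spawn, trains in
# bfloat16, and each process prints an XLA metrics report after measurement.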