@myleott
Created February 3, 2020 17:48
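
# Benchmark: times a forward/backward/optimizer-step loop for a transformer-like
# model (self-attention replaced by plain linear projections) on either TPU cores
# via torch_xla or GPUs via NCCL DistributedDataParallel, printing per-rank timings.
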
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim

platform = "gpu" if torch.cuda.is_available() else "tpu"

if platform == "tpu":
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.debug.metrics as met

    spawn_fn = xmp.spawn
    nprocs = 8

    def device_fn(rank):
        # xm.xla_device() returns the TPU core assigned to this process; use bfloat16
        return xm.xla_device(), torch.bfloat16

    def ddp_fn(model, rank):
        # no wrapper needed: xm.optimizer_step() all-reduces gradients across cores
        return model

    def itr_fn(dataloader, device):
        # preload batches onto the TPU device in the background
        return pl.ParallelLoader(dataloader, [device]).per_device_loader(device)

    def prep_fn(batch, device, dtype):
        # ParallelLoader has already moved the batch to the device
        return batch

    def step_fn(optimizer):
        xm.optimizer_step(optimizer)

    def metrics_fn():
        print(met.metrics_report())
elif platform == "gpu":
    import torch.distributed
    import torch.multiprocessing as mp

    spawn_fn = mp.spawn
    nprocs = 4

    def device_fn(rank):
        # one GPU per process; use float16
        return "cuda:" + str(rank), torch.float16

    def ddp_fn(model, rank):
        torch.distributed.init_process_group(
            "nccl",
            init_method="tcp://localhost:10234",
            world_size=nprocs,
            rank=rank,
        )
        return torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[rank],
        )

    def itr_fn(dataloader, device):
        return dataloader

    def prep_fn(batch, device, dtype):
        # move the batch of token ids to the local GPU (ids stay torch.long)
        return batch.to(device=device)

    def step_fn(optimizer):
        optimizer.step()

    def metrics_fn():
        pass

else:
    raise NotImplementedError
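
# Each platform branch above exposes the same small interface: spawn_fn and nprocs,
# plus device_fn / ddp_fn / itr_fn / prep_fn / step_fn / metrics_fn, so the model
# and benchmark loop below are platform-agnostic.
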
class Net(nn.Module):
    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 3*embed_dim),  # q, k, v input projection
                nn.Linear(3*embed_dim, embed_dim),  # skip self-attention
                nn.Linear(embed_dim, embed_dim),  # output projection
                nn.Dropout(),
            )
            for i in range(num_layers)
        ])
        self.layers_b = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 4*embed_dim),  # FFN
                nn.ReLU(),
                nn.Linear(4*embed_dim, embed_dim),  # FFN
                nn.Dropout(0.1),
            )
            for i in range(num_layers)
        ])
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)
            x = x + layer_b(x)
        x = self.out_proj(x)
        return x
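
# Note: with the default sizes (50k vocab, 1024-dim embeddings, 24 layer pairs),
# the model has roughly 480M parameters, which is what the "num model params"
# line printed by main() reports.
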
def main(rank):
    bsz = 128 // nprocs  # split a global batch of 128 sequences across workers
    seqlen = 512
    warmup_steps = 50
    measurement_steps = 50

    print("initializing dataloader")
    # dummy input: every example is the token sequence 1..seqlen (0 is the padding index)
    item = torch.arange(1, seqlen + 1, dtype=torch.long)
    dataloader = torch.utils.data.DataLoader(
        [item for _ in range(bsz * 1000)],
        batch_size=bsz,
        num_workers=4,
    )

    device, dtype = device_fn(rank)

    print("initializing model/opt/loss")
    model = Net().to(device=device, dtype=dtype)
    model = ddp_fn(model, rank)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    print("num model params: {}".format(sum(p.numel() for p in model.parameters())))

    print("initializing paraloader")
    itr = itr_fn(dataloader, device)

    print("beginning warmup")
    for i, batch in enumerate(itr):
        if i == warmup_steps:
            print("end warmup, begin measurement")
            start_time = time.time()
        batch = prep_fn(batch, device, dtype)
        x = model(batch)
        # predict each token from itself, purely to exercise the loss/backward path
        loss = loss_fn(
            x.view(-1, x.size(-1)),
            target=batch.view(-1),
        )
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        step_fn(optimizer)
        if i == warmup_steps + measurement_steps:
            measured_time = time.time() - start_time
            print(
                "end measurement, time for rank {}: {}"
                .format(rank, measured_time)
            )
            break

    metrics_fn()


if __name__ == "__main__":
    spawn_fn(main, args=(), nprocs=nprocs)
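
# Running this file directly spawns nprocs worker processes (8 TPU cores or 4 GPU
# ranks, depending on whether CUDA is available); each worker runs main(rank),
# warms up for 50 steps, times the following measurement window, and prints its
# own elapsed time.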