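"""
TPU vs. GPU training-throughput benchmark for a toy Transformer-shaped model.

The script builds a stack of LayerNorm/Linear blocks with Transformer-like
shapes (no real self-attention), feeds it synthetic token batches, and prints
the wall-clock time per rank for a fixed number of optimizer steps, on either
a TPU (bfloat16, via torch_xla) or a GPU (float16).
"""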
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
# Select which backend to benchmark: "tpu" requires torch_xla, "gpu" requires CUDA.
# platform = "tpu"
platform = "gpu"
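# Each backend exposes the same four helpers consumed by main():
#   device_fn()                   -> (device, dtype) to run the model on
#   itr_fn(dataloader, device)    -> batch iterator (ParallelLoader on TPU)
#   prep_fn(batch, device, dtype) -> batch prepared for the device
#   step_fn(optimizer)            -> one optimizer step (xm.optimizer_step on TPU)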
if platform == "tpu":
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp

    spawn_fn = xmp.spawn

    def device_fn():
        return xm.xla_device(), torch.bfloat16

    def itr_fn(dataloader, device):
        return pl.ParallelLoader(dataloader, [device]).per_device_loader(device)

    def prep_fn(batch, device, dtype):
        return batch

    def step_fn(optimizer):
        xm.optimizer_step(optimizer)

elif platform == "gpu":
    import torch.multiprocessing as mp

    spawn_fn = mp.spawn

    def device_fn():
        return "cuda", torch.float16

    def itr_fn(dataloader, device):
        return dataloader

    def prep_fn(batch, device, dtype):
        return batch.to(device=device)

    def step_fn(optimizer):
        optimizer.step()

else:
    raise NotImplementedError
class Net(nn.Module):
    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 3*embed_dim),  # q, k, v input projection
                nn.Linear(3*embed_dim, embed_dim),  # skip self-attention
                nn.Linear(embed_dim, embed_dim),    # output projection
                nn.Dropout(),
            )
            for i in range(num_layers)
        ])
        self.layers_b = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 4*embed_dim),  # FFN
                nn.ReLU(),
                nn.Linear(4*embed_dim, embed_dim),  # FFN
                nn.Dropout(0.1),
            )
            for i in range(num_layers)
        ])
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)
            x = x + layer_b(x)
        x = self.out_proj(x)
        return x
def main(rank):
    bsz = 8
    seqlen = 512
    warmup_steps = 50
    measurement_steps = 50

    print("initializing dataloader")
    item = torch.arange(1, seqlen + 1, dtype=torch.long)
    dataloader = torch.utils.data.DataLoader(
        [item for _ in range(bsz * 1000)],
        batch_size=bsz,
    )

    device, dtype = device_fn()

    print("initializing model/opt/loss")
    model = Net().to(device=device, dtype=dtype)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    print("num model params: {}".format(sum(p.numel() for p in model.parameters())))

    print("initializing data iterator")
    itr = itr_fn(dataloader, device)

    print("beginning warmup")
    for i, batch in enumerate(itr):
        if i == warmup_steps:
            print("end warmup, begin measurement")
            start_time = time.time()
        batch = prep_fn(batch, device, dtype)
        optimizer.zero_grad()  # gradients would otherwise accumulate across steps
        x = model(batch)
        loss = loss_fn(
            x.view(-1, x.size(-1)),
            target=batch.view(-1),
        )
        loss.backward()
        step_fn(optimizer)
        if i == warmup_steps + measurement_steps:
            measured_time = time.time() - start_time
            print(
                "end measurement, time for rank {}: {}"
                .format(rank, measured_time)
            )
            break


if __name__ == "__main__":
    spawn_fn(main, args=(), nprocs=1)
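# Usage (a sketch; the file name is an assumption, the gist leaves it unnamed):
#   python tpu_vs_gpu_benchmark.py
# Set platform = "tpu" above to exercise the torch_xla path (needs a TPU host
# with torch_xla installed); each spawned rank prints its own measured time.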