Last active
February 3, 2020 16:50
-
-
Save myleott/9017f443ab9d86ecf779d57fde58e1a2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
# Accelerator backend to benchmark: "gpu" (CUDA) or "tpu" (through torch_xla).
platform = "gpu"

# Reject unknown backends up front so the branches below stay two-way.
if platform not in ("tpu", "gpu"):
    raise NotImplementedError

if platform == "gpu":
    import torch.multiprocessing as mp

    spawn_fn = mp.spawn

    def step_fn(optimizer):
        """Apply one optimizer update."""
        optimizer.step()

    def prep_fn(batch, device, dtype):
        """Move the batch onto ``device``; ``dtype`` is accepted but unused."""
        return batch.to(device=device)

    def itr_fn(dataloader, device):
        """Return the dataloader unchanged; ``device`` is unused on GPU."""
        return dataloader

    def device_fn():
        """Return the (device, dtype) pair used to place the model."""
        return "cuda", torch.float16

else:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp

    spawn_fn = xmp.spawn

    def step_fn(optimizer):
        """Apply one optimizer update through the XLA helper."""
        xm.optimizer_step(optimizer)

    def prep_fn(batch, device, dtype):
        """Return the batch untouched; ``device`` and ``dtype`` are unused."""
        return batch

    def itr_fn(dataloader, device):
        """Wrap the dataloader in a ParallelLoader and return its per-device loader."""
        parallel_loader = pl.ParallelLoader(dataloader, [device])
        return parallel_loader.per_device_loader(device)

    def device_fn():
        """Return the (device, dtype) pair used to place the model."""
        return xm.xla_device(), torch.bfloat16
class Net(nn.Module):
    """Transformer-shaped stack of Linear layers used as a benchmark workload.

    Each layer pair is a residual "attention-like" block (projections only,
    no actual self-attention) followed by a residual feed-forward block.
    """

    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=num_embed,
            embedding_dim=embed_dim,
            padding_idx=0,
        )
        # Build order (embed, layers_a, layers_b, out_proj) is kept so that
        # parameter registration and initialisation order stay the same.
        self.layers_a = nn.ModuleList(
            [self._projection_block(embed_dim) for _ in range(num_layers)]
        )
        self.layers_b = nn.ModuleList(
            [self._ffn_block(embed_dim) for _ in range(num_layers)]
        )
        self.out_proj = nn.Linear(embed_dim, num_embed)

    @staticmethod
    def _projection_block(embed_dim):
        """Build the attention stand-in: q/k/v projection, skip, output projection."""
        qkv_dim = 3 * embed_dim
        return nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, qkv_dim),  # q, k, v input projection
            nn.Linear(qkv_dim, embed_dim),  # skip self-attention
            nn.Linear(embed_dim, embed_dim),  # output projection
            nn.Dropout(),
        )

    @staticmethod
    def _ffn_block(embed_dim):
        """Build the feed-forward block with a 4x hidden expansion."""
        hidden_dim = 4 * embed_dim
        return nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, hidden_dim),  # FFN
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),  # FFN
            nn.Dropout(0.1),
        )

    def forward(self, tokens):
        """Embed ``tokens``, apply every residual block pair, project to vocabulary size."""
        hidden = self.embed(tokens)
        for proj_block, ffn_block in zip(self.layers_a, self.layers_b):
            hidden = hidden + proj_block(hidden)
            hidden = hidden + ffn_block(hidden)
        return self.out_proj(hidden)
def _sync_device(device):
    """Block until queued CUDA work on ``device`` has finished."""
    # CUDA kernels launch asynchronously, so a host timestamp taken without a
    # sync does not bound the work. Only CUDA is handled here; other backends
    # (e.g. XLA) are left as they were.
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()


def main(rank):
    """Run the training-step benchmark in one worker process and print its timing.

    Runs ``warmup_steps`` untimed steps, then times exactly
    ``measurement_steps`` further steps (forward, loss, backward, optimizer
    update) on a synthetic dataset, and prints the elapsed wall-clock time.

    Args:
        rank: Process index supplied by the spawner; only used in the report.
    """
    bsz = 8
    seqlen = 512
    warmup_steps = 50
    measurement_steps = 50
    total_steps = warmup_steps + measurement_steps

    print("initializing dataloader")
    # Every sample is the same sequence 1..seqlen; token 0 is reserved for
    # padding / ignore_index and never appears.
    item = torch.arange(1, seqlen + 1, dtype=torch.long)
    dataloader = torch.utils.data.DataLoader(
        [item for _ in range(bsz * 1000)],
        batch_size=bsz,
    )

    device, dtype = device_fn()

    print("initializing model/opt/loss")
    model = Net().to(device=device, dtype=dtype)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    print("num model params: {}".format(sum(p.numel() for p in model.parameters())))

    print("initializing paraloader")
    itr = itr_fn(dataloader, device)

    print("beginning warmup")
    start_time = None
    for i, batch in enumerate(itr):
        if i == warmup_steps:
            # Drain warmup work so it is not charged to the measurement.
            _sync_device(device)
            print("end warmup, begin measurement")
            start_time = time.time()

        batch = prep_fn(batch, device, dtype)
        # Clear gradients from the previous step; without this they keep
        # accumulating across the whole run.
        optimizer.zero_grad()
        x = model(batch)
        # The input tokens double as targets; this is a throughput benchmark.
        loss = loss_fn(
            x.view(-1, x.size(-1)),
            target=batch.view(-1),
        )
        loss.backward()
        step_fn(optimizer)

        # Stop after step index total_steps - 1 so that exactly
        # measurement_steps steps fall inside the timed window.
        if i + 1 == total_steps:
            _sync_device(device)
            measured_time = time.time() - start_time
            print(
                "end measurement, time for rank {}: {}"
                .format(rank, measured_time)
            )
            break
# Script entry point: launch one benchmark worker through the platform's
# spawner, which calls main() with the process index as its first argument.
if __name__ == "__main__":
    spawn_fn(main, nprocs=1, args=())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment