@bjourne
Created June 21, 2020 22:55
from observations import ptb
from time import sleep, time
from torch.nn import *
from torch.optim import *
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch_xla.core.xla_model import (get_ordinal,
                                      is_master_ordinal,
                                      master_print,
                                      optimizer_step,
                                      rendezvous,
                                      xla_device,
                                      xrt_world_size)
from torch_xla.distributed.parallel_loader import ParallelLoader
from torch_xla.distributed.xla_multiprocessing import spawn
import torch

class RNN(Module):
    # Character-level language model: embedding -> LSTM -> linear projection
    # back onto the vocabulary.
    def __init__(self, vocab_size, embed_size, hidden_size,
                 n_layers, emb_dropout):
        super(RNN, self).__init__()
        self.encoder = Embedding(vocab_size, embed_size)
        self.lstm = LSTM(embed_size, hidden_size,
                         n_layers, batch_first = True)
        self.linear = Linear(hidden_size, vocab_size)
        self.drop = Dropout(emb_dropout)

    def forward(self, x, state):
        x2 = self.drop(self.encoder(x))
        out, state = self.lstm(x2, state)
        # Flatten (batch, seq_len, hidden) to (batch * seq_len, hidden) so the
        # logits line up with flattened targets in the loss.
        out = out.reshape(out.size(0) * out.size(1), out.size(2))
        out = self.linear(out)
        return out, state

    def init_state(self, batch_size, device):
        num_layers = self.lstm.num_layers
        hidden_size = self.lstm.hidden_size
        hs = torch.zeros(num_layers, batch_size, hidden_size)
        cs = torch.zeros(num_layers, batch_size, hidden_size)
        return hs.to(device), cs.to(device)
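
# Shape sketch (not part of the original gist; illustrative values only):
# feeding a (batch, seq_len) tensor of character indices returns logits of
# shape (batch * seq_len, vocab_size), which is what the cross-entropy loss
# below expects against y.reshape(-1).
#
#   model = RNN(vocab_size = 50, embed_size = 100, hidden_size = 512,
#               n_layers = 1, emb_dropout = 0.1)
#   state = model.init_state(batch_size = 4, device = 'cpu')
#   x = torch.randint(0, 50, (4, 16))     # (batch, seq_len)
#   logits, state = model(x, state)       # logits: (4 * 16, 50)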

def text_to_tensor(text):
    # Build a character vocabulary and encode the text as indices.
    ix2ch = sorted(set(text))
    ch2ix = {c : i for i, c in enumerate(ix2ch)}
    seq = torch.LongTensor([ch2ix[c] for c in text])
    return ix2ch, ch2ix, seq

def batchify(tensor, batch_size):
    # Trim the sequence so it divides evenly and lay it out as
    # batch_size rows of contiguous text.
    n_batches = tensor.size(0) // batch_size
    tensor = tensor[:n_batches * batch_size]
    return tensor.view(batch_size, -1)

def successor_samples(batched_tensor, seq_len):
    # Yield (input, target) windows where the target is the input shifted
    # one character to the right.
    for i in range(0, batched_tensor.size(1) - seq_len, seq_len):
        x = batched_tensor[:, i:i+seq_len]
        y = batched_tensor[:, (i+1):(i+1) + seq_len]
        yield x, y
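
# Worked example (illustrative, assumed values): with batch_size = 2 the
# tensor [0, 1, 2, 3, 4, 5, 6] is trimmed and laid out by batchify() as
# [[0, 1, 2], [3, 4, 5]]; successor_samples() with seq_len = 2 then yields
# x = [[0, 1], [3, 4]] paired with y = [[1, 2], [4, 5]], i.e. each target
# is the next character after the corresponding input position.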

def load_data(ptb_path, batch_size, seq_len):
    # PTB comes back as (train, valid, test) character strings; the
    # vocabulary is taken from the training split.
    texts = ptb(ptb_path)
    tensors = [text_to_tensor(text) for text in texts]
    ix2ch, ch2ix, _ = tensors[0]
    tensors = [batchify(t[2], batch_size) for t in tensors]
    data = [list(successor_samples(t, seq_len)) for t in tensors]
    return ix2ch, ch2ix, data

def fn(ix, flags):
    batch_size = flags['batch_size']
    seq_len = flags['seq_len']
    # Rendezvous so that only the master ordinal downloads the dataset;
    # the other processes wait here and read it from disk afterwards.
    if not is_master_ordinal():
        rendezvous('download_once')
    ix2ch, ch2ix, data = load_data('./data', batch_size, seq_len)
    train_ds, valid_ds, test_ds = data
    train_loader = DataLoader(
        train_ds,
        # The samples are already batched, so the loader passes them through.
        batch_size = None,
        sampler = DistributedSampler(
            train_ds,
            num_replicas = xrt_world_size(),
            rank = get_ordinal(),
            shuffle = True),
        shuffle = False,
        num_workers = 8)
    if is_master_ordinal():
        rendezvous('download_once')
    dev = xla_device()
    model = RNN(len(ix2ch), 100, 512, 1, 0.1).to(dev)
    crit = CrossEntropyLoss()
    opt = SGD(model.parameters(), lr = 4)
    for i in range(3):
        start = time()
        loader = ParallelLoader(train_loader, [dev]) \
            .per_device_loader(dev)
        state = model.init_state(batch_size, dev)
        model.train()
        for x, y in loader:
            opt.zero_grad()
            # Detach the hidden state so gradients do not flow across batches.
            state = tuple(s.detach() for s in state)
            y_hat, state = model(x, state)
            loss = crit(y_hat, y.reshape(-1))
            loss.backward()
            # optimizer_step() reduces gradients across the XLA replicas
            # before calling opt.step().
            optimizer_step(opt)
        elapsed = time() - start
        master_print('%.2f seconds for %d batches.'
                     % (elapsed, len(train_ds)))
    rendezvous('done')
    if ix == 0:
        sleep(0.5)

def main():
    flags = {'batch_size' : 32, 'seq_len' : 320}
    start = time()
    spawn(fn, args = (flags,), nprocs = 8, start_method = 'fork')
    elapsed = time() - start
    print('Took %.2f seconds.' % elapsed)

if __name__ == '__main__':
    main()
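
# Usage note (assumption, not stated in the gist): this script expects a host
# with torch_xla installed and XLA devices (e.g. a Cloud TPU) available;
# spawn() starts 8 worker processes and each one runs fn() against its own
# XLA device obtained via xla_device().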