@dlibenzi
Created May 17, 2020 15:14
import torch
import torch.nn as nn

# LSTM implemented as an explicit unroll over the sequence, with a mask on pad
# tokens so that every step uses static shapes (XLA friendly).
class XlaLSTM(nn.Module):

  def __init__(self, input_sz, hidden_sz, batch_first=False, pad_value=0):
    super(XlaLSTM, self).__init__()
    self.input_sz = input_sz
    self.hidden_size = hidden_sz
    if batch_first:
      self.batch_dim, self.sequence_dim = 0, 1
    else:
      self.batch_dim, self.sequence_dim = 1, 0
    self.pad_value = pad_value
    # Fused weight matrices for the input, forget, cell and output gates.
    self.weight_ih = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
    self.weight_hh = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
    self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
    self.init_weights()

  def init_weights(self):
    for p in self.parameters():
      if p.data.ndimension() >= 2:
        nn.init.xavier_uniform_(p.data)
      else:
        nn.init.zeros_(p.data)

  def sequence_slice(self, t, embed_out, embed_in):
    ot = embed_out[:, t, :] if self.sequence_dim == 1 else embed_out[t, :, :]
    it = embed_in[:, t] if self.sequence_dim == 1 else embed_in[t, :]
    return ot, it

  def sizes(self, embed_out):
    size = embed_out.size()
    return size[self.batch_dim], size[self.sequence_dim]

  def forward(self, embed_out, embed_in, init_states=None):
    bs, seq_sz = self.sizes(embed_out)
    if init_states is None:
      h_t, c_t = (torch.zeros(bs, self.hidden_size, device=embed_out.device),
                  torch.zeros(bs, self.hidden_size, device=embed_out.device))
    else:
      h_t, c_t = init_states
    hstate = []
    HS = self.hidden_size
    for t in range(0, seq_sz):
      feat, iseq = self.sequence_slice(t, embed_out, embed_in)
      gates = feat @ self.weight_ih + h_t @ self.weight_hh + self.bias
      i_t, f_t, g_t, o_t = (
          torch.sigmoid(gates[:, :HS]),
          torch.sigmoid(gates[:, HS:HS * 2]),
          torch.tanh(gates[:, HS * 2:HS * 3]),
          torch.sigmoid(gates[:, HS * 3:]),
      )
      # Only advance the state for rows whose current token is not the pad
      # value; padded positions carry the previous state forward unchanged.
      fwd = iseq.unsqueeze(1) != self.pad_value
      c_t = torch.where(fwd, f_t * c_t + i_t * g_t, c_t)
      h_t = torch.where(fwd, o_t * torch.tanh(c_t), h_t)
      hstate.append(h_t.unsqueeze(self.sequence_dim))
    hstate = torch.cat(hstate, dim=self.sequence_dim)
    return hstate, (h_t, c_t)
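

# A minimal sanity sketch for XlaLSTM (not part of the original gist): it runs
# on CPU, uses made-up batch/sequence/feature sizes, and only checks output
# shapes plus the fact that the state stops updating once a row reaches its
# pad tokens. The function name and values are illustrative assumptions.
def _xla_lstm_shape_check():
  bs, seq, in_sz, hid_sz = 3, 5, 4, 6
  lstm = XlaLSTM(in_sz, hid_sz, batch_first=True)
  feats = torch.randn(bs, seq, in_sz)
  # Token ids; zeros mark padding, matching the default pad_value=0.
  tokens = torch.tensor([[1, 2, 3, 0, 0],
                         [4, 5, 0, 0, 0],
                         [6, 7, 8, 9, 10]])
  hstate, (h_t, c_t) = lstm(feats, tokens)
  assert hstate.shape == (bs, seq, hid_sz)
  assert h_t.shape == (bs, hid_sz) and c_t.shape == (bs, hid_sz)
  # Padded steps carry the previous hidden state forward unchanged.
  assert torch.allclose(hstate[0, 2], hstate[0, 4])
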
#### TEST
import collections
import csv
import nltk
import random
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.utils.gcsfs as gcs

def maybe_int(v):
  try:
    return int(v)
  except ValueError:
    pass

def read_dataset(path):
  with gcs.generic_open(path, mode='r') as fd:
    reader = csv.DictReader(fd)
    vocab = {'PAD': 0}
    sentences, max_target = [], -1
    for fields in reader:
      target = maybe_int(fields['target'])
      if target is None:
        continue
      sentence = []
      for tok in nltk.word_tokenize(fields['sentence']):
        ltok = tok.lower()
        tid = vocab.get(ltok, None)
        if tid is None:
          tid = len(vocab)
          vocab[ltok] = tid
        sentence.append(tid)
      if sentence:
        sentences.append((sentence, target))
        max_target = max(max_target, target)
  return sentences, vocab, max_target

def split_and_pad(sentences, splits, pad_value=0):
  splits = sorted(splits)
  split_dict = collections.defaultdict(list)
  discarded = 0
  for sentence, target in sentences:
    lsent = len(sentence)
    # Route each sentence to a fixed-length bucket; sentences longer than the
    # largest bucket are dropped.
    for i in range(0, len(splits)):
      if lsent < splits[i]:
        break
    seqlen = splits[i]
    if lsent > seqlen:
      discarded += 1
      continue
    padded_sentence = sentence + [pad_value] * (seqlen - lsent)
    split_dict[seqlen].append((padded_sentence, target))
  return split_dict, discarded

def make_batches(split_dict, batch_size):
  batch_dict = dict()
  for seqlen, slist in split_dict.items():
    batches = []
    i = 0
    while i + batch_size <= len(slist):
      batches.append(slist[i:i + batch_size])
      i += batch_size
    if i < len(slist):
      # Fill the last, partial batch with randomly resampled sentences so that
      # every batch has exactly batch_size elements.
      batch = slist[i:]
      while len(batch) < batch_size:
        batch.append(slist[random.randint(0, len(slist) - 1)])
      batches.append(batch)
    batch_dict[seqlen] = batches
  return batch_dict

def to_one_hot(y, n_dims, dtype=torch.float32):
  scatter_dim = len(y.size())
  y_tensor = y.view(*y.size(), -1)
  zeros = torch.zeros(*y.size(), n_dims, dtype=dtype, device=y.device)
  return zeros.scatter(scatter_dim, y_tensor, 1)

def gen_tensors(batch_dict, target_dims, shuffle=True):
  slist = []
  for seqlen in sorted(batch_dict.keys()):
    slist += batch_dict[seqlen]
  if shuffle:
    random.shuffle(slist)
  tensors = []
  for bseq in slist:
    sentence_data, target_data = [], []
    for sentence, target in bseq:
      sentence_data.append(sentence)
      target_data.append(target)
    sentence_tensor = torch.tensor(sentence_data, dtype=torch.int64)
    target_tensor = torch.tensor(target_data, dtype=torch.int64)
    onehot_tensor = to_one_hot(target_tensor, target_dims)
    tensors.append((sentence_tensor, onehot_tensor))
  return tensors
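

# A small, self-contained illustration of the bucketing pipeline above (not in
# the original gist): a few synthetic token-id sentences are padded to fixed
# bucket lengths, grouped into fixed-size batches, and turned into
# (sentence, one-hot target) tensor pairs. All values below are made up.
def _bucketing_example():
  sentences = [([1, 2, 3], 0), ([4, 5], 1), ([6, 7, 8, 9, 10], 1)]
  split_dict, discarded = split_and_pad(sentences, splits=(4, 8))
  batch_dict = make_batches(split_dict, batch_size=2)
  tensors = gen_tensors(batch_dict, target_dims=2, shuffle=False)
  for sentence_tensor, onehot_tensor in tensors:
    # Every batch has a fixed (batch_size, bucket_length) shape, which keeps
    # the set of distinct shapes (and hence XLA recompilations) bounded.
    print(sentence_tensor.shape, onehot_tensor.shape)
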

class TestClassifier(nn.Module):

  def __init__(self,
               vocab_size,
               embedding_dim,
               hidden_dim,
               output_dim,
               padding_idx=None):
    super().__init__()
    self.embedding = nn.Embedding(
        vocab_size, embedding_dim, padding_idx=padding_idx)
    self.lstm = XlaLSTM(embedding_dim, hidden_dim, batch_first=True)
    self.fc = nn.Linear(hidden_dim, output_dim)
    self.act = nn.Sigmoid()

  def forward(self, sentence_tensor):
    embedded = self.embedding(sentence_tensor)
    output, (hidden, cell) = self.lstm(embedded, sentence_tensor)
    dense_outputs = self.fc(hidden)
    return self.act(dense_outputs)

def test_model(path,
               device,
               splits,
               batch_size,
               embed_size=8,
               per_target_hs_size=256,
               lr=0.01,
               momentum=0.0,
               epochs=1,
               log_interval=10):
  MIN_HS_SIZE = 128
  MAX_HS_SIZE = 1024 * 32
  sentences, vocab, max_target = read_dataset(path)
  split_dict, discarded = split_and_pad(sentences, splits)
  batch_dict = make_batches(split_dict, batch_size)
  train_tensors = gen_tensors(batch_dict, max_target + 1)
  # Scale the hidden size with the number of target classes, clamped to a
  # sane range.
  hs_size = min(MAX_HS_SIZE,
                max((max_target + 1) * per_target_hs_size, MIN_HS_SIZE))
  model = TestClassifier(
      len(vocab), embed_size, hs_size, max_target + 1, padding_idx=0)
  model.to(device)
  model.train()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
  criterion = nn.MSELoss().to(device)
  for epoch in range(0, epochs):
    n = 0
    for sentence_tensor, target_tensor in train_tensors:
      sentence_tensor = sentence_tensor.to(device)
      target_tensor = target_tensor.to(device)
      optimizer.zero_grad()
      output = model(sentence_tensor)
      loss = criterion(output, target_tensor)
      loss.backward()
      xm.optimizer_step(optimizer, barrier=True)
      # print('OUTPUT:\n', torch_xla._XLAC._get_xla_tensors_text([output]))
      n += 1
      if n % log_interval == 0:
        print('[{}] Loss: {:.4f}'.format(epoch, loss.cpu().item()))

torch.manual_seed(11)
device = xm.xla_device()
csv_path = 'gs://davide-stg1/lstm_test_data.csv'
nltk.download('punkt')
test_model(
    csv_path,
    device, (8, 16, 32, 64),
    batch_size=16,
    embed_size=128,
    lr=0.01,
    momentum=0.9,
    epochs=10)
print(met.metrics_report())