@ritchie46
Created June 20, 2018 09:28
minibatches in pytorch
"""
How to do minibatches for RNNs in pytorch
Assume we feed characters to the model and predict the language of the words.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, precision_score


def prepare_batch(x, y):
    # determine the maximum word length in the batch and zero-pad the tensors
    n_max = max([a.shape[0] for a in x])
    pad = np.zeros((n_max, len(x), x[0].shape[2]))
    lengths = []
    for i in range(len(x)):
        lengths.append(x[i].shape[0])
        # shape = (n_characters, n_batch, n_features)
        pad[:x[i].shape[0], i:i + 1, :] = x[i]
    # the mini-batch needs to be in decreasing length order for pack_padded_sequence (pytorch)
    lengths = np.array(lengths)
    idx = np.argsort(lengths)[::-1]
    return pad[:, idx, :], lengths[idx], y[idx]
# the tensors in x_train have varying first dimensions because the words differ in length
# x_train = list[ array(n_characters, n_batches == 1, n_features), ..., array(n_characters, n_batches == 1, n_features) ]
# y_train = array([1, 12, 6, ... 3, 1]) (labels)
# pad, lengths, _ = prepare_batch(x_train[:10], y_train[:10])
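# A minimal illustration (toy data, not from the original gist): two "words" of
# length 3 and 5 with 4 features per character. The longer word comes first in
# the returned batch because pack_padded_sequence needs decreasing lengths.
x_toy = [np.random.rand(3, 1, 4), np.random.rand(5, 1, 4)]
y_toy = np.array([0, 1])
pad_toy, lengths_toy, y_toy_sorted = prepare_batch(x_toy, y_toy)
# pad_toy.shape == (5, 2, 4), lengths_toy == array([5, 3]), y_toy_sorted == array([1, 0])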
## Model
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, hidden_size, num_layers=1)
        self.linear1 = nn.Linear(hidden_size, output_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)
        self.inference = False

    def forward(self, x, lengths):
        hidden = self.init_hidden(x)
        # pack_padded_sequence so that the padded items in the sequence are not shown to the RNN
        x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
        x, hidden = self.rnn(x, hidden)
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x)
        # select only the last valid output per word (many to one)
        lengths = lengths - 1
        x = x[lengths, np.arange(x.shape[1]), :]
        x = F.relu(x)
        x = self.linear1(x)
        if self.inference:
            output = self.softmax(x)
        else:
            output = self.logsoftmax(x)
        return output

    def init_hidden(self, x):
        if next(self.parameters()).is_cuda:
            return torch.zeros(1, x.size(1), self.hidden_size).float().cuda()
        return torch.zeros(1, x.size(1), self.hidden_size).float()
m = RNN(x_train[0].shape[2], 50, len(np.unique(y_train)))
m(torch.tensor(pad).float(), lengths).shape
>>> # some shape
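# (shape note, inferred from the code rather than from the gist output: with the
#  10-word batch above this is torch.Size([10, len(np.unique(y_train))]),
#  i.e. one row per word, because only the last time step is kept)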
criterion = nn.NLLLoss()
optim = torch.optim.Adam(m.parameters(), lr=0.001)
## Train
epochs = 25
batch_size = 50
print_iter = 100
if torch.cuda.is_available():
    m.cuda()
tboard = True
m.train(True)


def get_prediction(x, y):
    pad, lengths, y = prepare_batch(x, y)
    x = torch.tensor(pad).float()
    lengths = torch.tensor(lengths).long()
    y = torch.tensor(y, dtype=torch.long)
    if next(m.parameters()).is_cuda:
        x, lengths, y = x.cuda(), lengths.cuda(), y.cuda()
    return m(x, lengths), y


def test_eval():
    batch_pred, batch_y = get_prediction(x_test, y_test)
    batch_pred = batch_pred.cpu().data.numpy().argmax(1)
    batch_y = batch_y.cpu().data.numpy()
    return batch_pred, batch_y


# x_train and y_train are assumed to be numpy (object) arrays so they can be
# shuffled and sliced with fancy indexing
idx = np.arange(x_train.shape[0])
for epoch in range(epochs):
    np.random.shuffle(idx)
    x_train = x_train[idx]
    y_train = y_train[idx]
    current_batch = 0
    for iteration in range(y_train.shape[0] // batch_size):
        batch_x = x_train[current_batch: current_batch + batch_size]
        batch_y = y_train[current_batch: current_batch + batch_size]
        current_batch += batch_size
        optim.zero_grad()
        if len(batch_x) > 0:
            batch_pred, batch_y = get_prediction(batch_x, batch_y)
            loss = criterion(batch_pred, batch_y)
            loss.backward()
            optim.step()
        if iteration % print_iter == 0:
            with torch.no_grad():
                m.train(False)
                batch_pred, batch_y = test_eval()
                f1 = f1_score(batch_y, batch_pred, average='weighted')
                precision = precision_score(batch_y, batch_pred, average='weighted')
                print(loss.item(), '\titeration:', iteration, '\tepoch', epoch,
                      '\tf1', f1, '\tprecision', precision)
            m.train(True)
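# A possible inference sketch (an addition, not part of the original gist): the
# `inference` flag switches the head from log-softmax to softmax so the model
# returns class probabilities. Assumes x_test / y_test exist as in test_eval().
m.train(False)
m.inference = True
with torch.no_grad():
    probs, _ = get_prediction(x_test[:5], y_test[:5])
print(probs.cpu().numpy())  # each row sums to ~1 over the language classes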
@harkmug commented Jul 24, 2018
Hi,

A quick, naive question. In line 49, is the initial hidden state strictly needed (I think in pytorch > 0.3 that defaults to zeros, as specified in init_hidden, line 65)?

I am trying to understand if I am correct in assuming that the GRU automatically passes the hidden states between the sequences within a single batch item and there is no across-batch-item passing of hidden states (as might be the case in long continuous texts that are split into batches).

Thanks for your time!
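For reference, a minimal check of the first point (an illustrative sketch, not from the gist author): in recent PyTorch, nn.GRU defaults the initial hidden state to zeros when it is omitted, so the explicit init_hidden call gives the same outputs as passing nothing.

import torch
import torch.nn as nn

gru = nn.GRU(input_size=4, hidden_size=8, num_layers=1)
x = torch.randn(5, 3, 4)                        # (seq_len, batch, features)
out_default, _ = gru(x)                         # h0 omitted -> zeros
out_zeros, _ = gru(x, torch.zeros(1, 3, 8))     # explicit zero h0
print(torch.allclose(out_default, out_zeros))   # True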
