Skip to content

Instantly share code, notes, and snippets.

Last active November 4, 2020 20:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save santhalakshminarayana/39c738fcaad95619b34d83345c13da08 to your computer and use it in GitHub Desktop.
Save santhalakshminarayana/39c738fcaad95619b34d83345c13da08 to your computer and use it in GitHub Desktop.
Quotes LSTM model - Medium
def get_batches_x(tot_seq, batch_size, x_data=None, y_data=None):
    """Yield shuffled mini-batches of ``(inputs, targets)``.

    A fresh random permutation of the indices ``0 .. tot_seq-1`` is drawn on
    every call and consumed in consecutive slices, so each sample appears
    exactly once per pass. The last batch is smaller when ``tot_seq`` is not
    a multiple of ``batch_size``.

    Args:
        tot_seq: number of samples to draw from.
        batch_size: maximum number of samples per batch; must be positive.
        x_data, y_data: arrays indexable with a list of ints (e.g. NumPy
            arrays). When omitted, the module-level ``X`` / ``Y`` are used,
            exactly as before.

    Yields:
        ``(x_data[ids], y_data[ids])`` for each batch of indices ``ids``.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised on the first
            ``next()``, like any generator body).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Fall back to the module-level data so existing callers are unaffected.
    if x_data is None:
        x_data = X
    if y_data is None:
        y_data = Y
    order = np.random.permutation(tot_seq).tolist()
    for start in range(0, tot_seq, batch_size):
        batch_ids = order[start:start + batch_size]
        yield x_data[batch_ids], y_data[batch_ids]
class Quote_Generator(nn.Module):
    """LSTM model that scores the next token from a fixed-length input window.

    The LSTM runs over the whole window. Its outputs for every time step are
    flattened into one vector per sample, and a linear layer maps that vector
    to unnormalized scores (logits) over the vocabulary.

    Args:
        embed_size: size of the feature vector at each time step of the input.
        hidden_size: number of LSTM hidden units.
        vocab_len: number of output classes (vocabulary size).
        seq_len: number of time steps in every input window. Defaults to 5,
            the length this model was originally hard-coded for.
    """

    def __init__(self, embed_size, hidden_size, vocab_len, seq_len=5):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True).to(device)
        self.dropout = nn.Dropout(0.4)
        # Consumes the LSTM outputs of all ``seq_len`` steps, concatenated.
        self.dense = nn.Linear(hidden_size * seq_len, vocab_len).to(device)

    def forward(self, x, prev_state=None):
        """Return ``(logits, state)`` for a batch of input windows.

        Args:
            x: input features. With ``batch_first=True`` this is presumably
                ``(batch, seq_len, embed_size)``; TODO confirm with callers.
            prev_state: accepted for API compatibility but not used. The
                LSTM is called without a hidden state, so PyTorch starts it
                from zeros.

        Returns:
            ``logits`` with one row per window (``vocab_len`` columns) and the
            LSTM's final ``(h_n, c_n)``.
        """
        # NOTE(review): prev_state is deliberately not passed on. The dense
        # layer sees every step of the window, so each window looks
        # self-contained. Also, the batches produced by get_batches_x are
        # random permutations, so a carried-over state would come from
        # unrelated samples.
        output, state = self.lstm(x)
        output = self.dropout(output)
        # Flatten the per-step outputs into one row per window. Use the
        # layer's own input width rather than the module-level hidden_size
        # global and a literal 5.
        logits = self.dense(output.reshape(-1, self.dense.in_features))
        return logits, state

    def zero_states(self, batch_size):
        """Return a zero ``(h_0, c_0)`` pair for ``batch_size`` sequences.

        Each tensor has shape ``(1, batch_size, hidden_size)`` (a single LSTM
        layer). Both are created directly on the device holding the model's
        parameters.
        """
        param_device = next(self.parameters()).device
        shape = (1, batch_size, self.hidden_size)
        return (torch.zeros(shape, device=param_device),
                torch.zeros(shape, device=param_device))
def entropy_loss(y, y_hat):
    """Return the mean cross-entropy between target weights and logits.

    Args:
        y: targets with the same shape as ``y_hat``. Each row weights the
            log-probabilities of its classes. A one-hot row is presumably how
            the training data is encoded; confirm with the caller.
        y_hat: unnormalized scores (logits), with classes along dim 1.

    Returns:
        A scalar tensor, ``-(y * log_softmax(y_hat)).sum(dim=1).mean()``, on
        the same device as the inputs.
    """
    # log_softmax is the numerically stable form of log(softmax(...)). The
    # two-step version returns log(0) = -inf for very negative logits, and
    # 0 * -inf = NaN then poisons the whole loss, even where y is 0.
    log_probs = F.log_softmax(y_hat, dim=1)
    return -(y * log_probs).sum(dim=1).mean()
def qt_train(qt_gen, epochs=101, batch_size=4096, lr=0.001):
    """Train ``qt_gen`` in place with Adam on shuffled mini-batches.

    Batches come from ``get_batches_x(tot_seq, batch_size)``, which reads the
    module-level data. On every 10th epoch, starting at epoch 0, the mean
    batch loss accumulated since the previous report is printed.

    Args:
        qt_gen: the ``Quote_Generator`` model to train.
        epochs: number of passes over the data (default 101, as before).
        batch_size: maximum samples per mini-batch (default 4096, as before).
        lr: Adam learning rate (default 0.001, as before).
    """
    losses = []
    optimizer = torch.optim.Adam(qt_gen.parameters(), lr=lr)
    qt_gen.train()  # make sure dropout is active during training
    for epoch in tqdm(range(epochs)):
        for x, y in get_batches_x(tot_seq, batch_size):
            x = torch.tensor(x).float().to(device)
            y = torch.tensor(y).long().to(device)
            # Batches are random permutations of independent windows, so each
            # starts from a fresh zero state. It is sized to the actual batch,
            # because the final batch may be smaller than batch_size.
            state = qt_gen.zero_states(x.size(0))
            optimizer.zero_grad()
            logits, _ = qt_gen(x, state)
            loss = entropy_loss(y, logits)
            loss.backward()
            # Clip after backward so the freshly computed gradients are bounded.
            nn.utils.clip_grad_norm_(qt_gen.parameters(), 5)
            optimizer.step()
            losses.append(loss.item())
        if epoch % 10 == 0:
            print(f"Epoch : {epoch} ----> Loss : {np.array(losses).mean()}")
            losses = []
# Hyper-parameters for the quote model: per-step input feature size and
# number of LSTM hidden units.
embed_size, hidden_size = 128, 64
qt_gen = Quote_Generator(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_len=vocab_len,
).to(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment