Skip to content

Instantly share code, notes, and snippets.

@santhalakshminarayana
Created January 6, 2020 07:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save santhalakshminarayana/fdf8c6ed2ab2dda04d535cfbbc655501 to your computer and use it in GitHub Desktop.
Quotes Glove Train - Medium
def f_x(x, x_max, alpha):
    """GloVe weighting function f(X): (x / x_max) ** alpha, capped at 1.

    Down-weights rare co-occurrences and caps the weight of frequent
    ones so they do not dominate the loss.
    """
    scaled = torch.pow(x / x_max, alpha)
    # Elementwise cap at 1.0, then move to the training device.
    capped = torch.min(scaled, torch.ones_like(scaled))
    return capped.to(device)
def weight_mse(w_x, x, log_x):
    """Weighted mean-squared error: mean of w_x * (x - log_x) ** 2.

    Computes the per-element squared error first, scales each element by
    its weight, then reduces with a single mean over all elements.
    """
    per_element = F.mse_loss(x, log_x, reduction='none')
    weighted = w_x * per_element
    return torch.mean(weighted).to(device)
def glove_train(glove):
    """Train a GloVe model with Adagrad on the co-occurrence pairs.

    Fixes over the original:
    - renamed misspelled local ``n_bathces`` -> ``n_batches``;
    - call ``glove(...)`` instead of ``glove.forward(...)`` so that
      nn.Module hooks run;
    - break after exactly ``n_batches`` batches per epoch (the original
      condition processed one extra batch).

    Args:
        glove: an nn.Module mapping (word indices, context indices) to
            predicted log co-occurrence values.
    """
    epochs = 100
    batch_size = 512
    x_max = 1       # saturation point of the weighting function f(X)
    alpha = 0.75    # exponent of the weighting function
    loss_trace = []
    # NOTE(review): ``occs`` and ``get_batch`` come from module scope —
    # presumably the flattened co-occurrence counts and a batch generator.
    n_batches = int(math.floor(len(occs) / batch_size))
    optimizer = optim.Adagrad(glove.parameters(), lr=0.05)
    for epoch in tqdm(range(epochs)):
        batch_count = -1
        for ind_1, ind_2, occ in get_batch(batch_size):
            batch_count += 1
            occ = torch.FloatTensor(occ).to(device)
            optimizer.zero_grad()
            # __call__ rather than .forward so module hooks are honored.
            y_hat = glove(torch.tensor(ind_1).to(device),
                          torch.tensor(ind_2).to(device))
            w_x = f_x(occ, x_max, alpha)
            # GloVe objective: weighted MSE between prediction and log(X_ij).
            loss = weight_mse(w_x, y_hat, torch.log(occ))
            loss.backward()
            optimizer.step()
            loss_trace.append(loss.item())
            # batch_count runs 0..n_batches-1: stop after n_batches batches.
            if batch_count == n_batches - 1:
                break
        # Reports the last batch's loss, not the epoch mean.
        print(f"Epoch : {epoch + 1} ----> Loss : {loss_trace[-1]}")
# Embedding dimensionality for each word vector.
num_dim = 128
# ``Glove``, ``vocab_len`` and ``device`` are defined earlier in the file
# (not visible here); the model is moved to the training device before use.
glove = Glove(vocab_len, num_dim).to(device)
glove_train(glove)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment