Created
October 18, 2019 01:16
-
-
Save polm/f512451edc0a670a334c76ddef99815c to your computer and use it in GitHub Desktop.
Simplified version of embedding training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html
# http://bytepawn.com/hacker-news-embeddings-with-pytorch.html
import json
from random import choice, random, seed, shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Seed both random sources the script draws from so that runs are repeatable:
# torch initialises the embedding weights, while Python's `random` module
# generates the minibatches.
torch.manual_seed(1)
seed(1)
class DummyEmbedding(torch.nn.Module):
    """Score (alpha, beta) index pairs by the dot product of their embeddings.

    Two independent embedding tables are kept, one per side of the pair.
    Both are created with ``max_norm=1.0``, so every row that is looked up is
    renormalised in place to an L2 norm of at most 1; each score therefore
    lies in ``[-1, 1]``.
    """

    def __init__(self, num_alpha, num_beta, embedding_dim=64):
        """Create the two embedding tables.

        Args:
            num_alpha: Number of rows in the alpha table.
            num_beta: Number of rows in the beta table.
            embedding_dim: Width of every embedding vector.
        """
        super().__init__()
        self.alpha_embedding = torch.nn.Embedding(num_alpha, embedding_dim, max_norm=1.0)
        self.beta_embedding = torch.nn.Embedding(num_beta, embedding_dim, max_norm=1.0)
        self.embedding_dim = embedding_dim

    def forward(self, batch):
        """Return one dot-product score per row of ``batch``.

        Args:
            batch: Sequence of rows whose first two items are the alpha index
                and the beta index. Any further items (such as a label) are
                ignored.

        Returns:
            A 1-D float tensor of length ``len(batch)``.
        """
        # Build the index tensors on the same device as the weights so the
        # lookup keeps working after the module has been moved with .to().
        device = self.alpha_embedding.weight.device
        alpha_ids = torch.tensor([row[0] for row in batch], dtype=torch.long, device=device)
        beta_ids = torch.tensor([row[1] for row in batch], dtype=torch.long, device=device)

        alpha_vecs = self.alpha_embedding(alpha_ids)
        beta_vecs = self.beta_embedding(beta_ids)

        # Row-wise dot product: multiply element-wise, then sum over the
        # embedding dimension.
        return (alpha_vecs * beta_vecs).sum(dim=1)
def build_minibatch(num_positives, num_negatives, num_alpha=100, num_beta=100):
    """Build a shuffled minibatch of random ``[alpha, beta, label]`` rows.

    Exactly ``num_positives`` rows carry the label ``1`` and exactly
    ``num_negatives`` rows carry the label ``-1``. The indices of every row
    are drawn uniformly at random and are unrelated to the label.

    Args:
        num_positives: Number of rows labelled ``1``.
        num_negatives: Number of rows labelled ``-1``.
        num_alpha: Exclusive upper bound for the alpha index.
        num_beta: Exclusive upper bound for the beta index.

    Returns:
        A list of ``[alpha_index, beta_index, label]`` lists in random order.
    """
    labels = [1] * num_positives + [-1] * num_negatives
    minibatch = [
        [int(random() * num_alpha), int(random() * num_beta), label]
        for label in labels
    ]
    # Rows are generated positives-first; shuffle so labels are interleaved.
    shuffle(minibatch)
    return minibatch
# Model with 100 alpha rows, 100 beta rows and 64-dimensional embeddings,
# trained with Adam on the mean squared error between the dot-product scores
# and the +/-1 labels.
model = DummyEmbedding(100, 100, 64)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.MSELoss(reduction='mean')

for i in range(50):
    for j in range(100):
        optimizer.zero_grad()
        # XXX (original author's note): make these numbers smaller and it works
        minibatch = build_minibatch(500, 500)
        target = torch.tensor([row[2] for row in minibatch], dtype=torch.float)
        # Call the module itself rather than .forward() so that any
        # registered hooks run.
        y = model(minibatch)
        loss = loss_function(y, target)
        if i == 0 and j == 0:
            # Loss of the very first batch, before any optimizer step.
            print('r: loss = %.3f' % loss.item())
        loss.backward()
        optimizer.step()
    # Report the loss of the last minibatch of this outer iteration.
    print('%s: loss = %.3f' % (i, loss.item()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment