@polm · Created October 18, 2019 01:16
Simplified version of embedding training: two Embedding tables (with max_norm) are trained so that dot products of random index pairs match random ±1 targets.
#!/usr/bin/env python3
# Minimal embedding-training example, adapted from:
# https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html
# http://bytepawn.com/hacker-news-embeddings-with-pytorch.html
import torch
from random import random, shuffle

torch.manual_seed(1)  # fixed seed for reproducibility
class DummyEmbedding(torch.nn.Module):
    """Two embedding tables whose pairwise dot products are the model output."""

    def __init__(self, num_alpha, num_beta, embedding_dim=64):
        super().__init__()
        self.alpha_embedding = torch.nn.Embedding(num_alpha, embedding_dim, max_norm=1.0)
        self.beta_embedding = torch.nn.Embedding(num_beta, embedding_dim, max_norm=1.0)
        self.embedding_dim = embedding_dim

    def forward(self, batch):
        # batch is a list of [alpha_index, beta_index, label] triples;
        # only the two indices are used here.
        t1 = self.alpha_embedding(torch.LongTensor([v[0] for v in batch]))
        t2 = self.beta_embedding(torch.LongTensor([v[1] for v in batch]))
        # Batched dot product: (B, 1, D) bmm (B, D, 1) -> (B, 1, 1)
        dot_products = torch.bmm(
            t1.contiguous().view(len(batch), 1, self.embedding_dim),
            t2.contiguous().view(len(batch), self.embedding_dim, 1),
        )
        return dot_products.contiguous().view(len(batch))
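
# Aside (a sketch, not part of the original gist): the bmm/view dance above is
# equivalent to an elementwise product followed by a sum over the embedding
# dimension, since t1 and t2 both have shape (B, D):
#
#     scores = (t1 * t2).sum(dim=1)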
def build_minibatch(num_positives, num_negatives):
    """Random [alpha_index, beta_index, label] triples, shuffled together."""
    # The original version drew every label with choice([-1, 1]), ignoring the
    # two count arguments; this respects them, which also makes the shuffle do
    # something.
    minibatch = []
    for _ in range(num_positives):
        minibatch.append([int(random() * 100), int(random() * 100), 1])
    for _ in range(num_negatives):
        minibatch.append([int(random() * 100), int(random() * 100), -1])
    shuffle(minibatch)
    return minibatch
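
# For example, build_minibatch(2, 2) might return something like
# [[17, 4, 1], [83, 52, -1], [9, 61, -1], [40, 3, 1]]
# (indices are uniform over 0-99; exact values depend on the RNG state).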
model = DummyEmbedding(100, 100, 64)
optimizer = torch.optim.Adam(model.parameters())
# MSE against the +/-1 labels pushes each pair's dot product toward its target.
loss_function = torch.nn.MSELoss(reduction='mean')
for i in range(50):
    for j in range(100):
        optimizer.zero_grad()
        #XXX make these numbers smaller and it works
        minibatch = build_minibatch(500, 500)
        target = torch.FloatTensor([v[2] for v in minibatch])
        y = model(minibatch)
        loss = loss_function(y, target)
        if i == 0 and j == 0:
            print('initial: loss = %.3f' % float(loss))
        loss.backward()
        optimizer.step()
    print('%s: loss = %.3f' % (i, float(loss)))
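
Run the script directly and it prints the loss once per outer iteration. As the XXX comment notes, the problem this gist demonstrates only shows up with large minibatches; shrinking the two build_minibatch arguments makes training proceed normally. A quick shape sanity check, independent of the training loop (a sketch added here, not part of the original gist):

model = DummyEmbedding(100, 100, 64)
print(model(build_minibatch(2, 2)).shape)  # expect torch.Size([4])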