Skip to content

Instantly share code, notes, and snippets.

@lucidrains
Created December 16, 2019 00:56
Show Gist options
  • Save lucidrains/366ac632ea7355c6b0563687eb5bc129 to your computer and use it in GitHub Desktop.
import torch
from torch import nn
import torch.nn.functional as F
class Recommend(nn.Module):
    """Score (user, item) pairs with an MLP over concatenated embeddings.

    Args:
        num_items: number of rows in the item embedding table.
        num_users: number of rows in the user embedding table.
        dims: width of each embedding; the hidden layer is ``4 * dims`` wide.
    """

    def __init__(self, num_items: int, num_users: int, dims: int):
        super().__init__()
        self.user_embed = nn.Embedding(num_users, dims)
        self.item_embed = nn.Embedding(num_items, dims)
        # [user embedding | item embedding] (2*dims) -> hidden (4*dims) -> one
        # sigmoid score in (0, 1).
        self.net = nn.Sequential(
            nn.Linear(2 * dims, 4 * dims),
            nn.LeakyReLU(),
            nn.Linear(4 * dims, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a sigmoid score for each (user_id, item_id) row of *x*.

        Column 0 of *x* indexes the user table and column 1 the item table.
        Presumably *x* is an integer tensor of shape (batch, 2); confirm
        against callers. For such input the result has shape (batch, 1).
        """
        users = self.user_embed(x[:, 0])
        items = self.item_embed(x[:, 1])
        features = torch.cat((users, items), dim=1)
        return self.net(features)
def cluster(embeddings, num_clusters=4, iters=10):
    """Group the rows of *embeddings* into *num_clusters* clusters with plain k-means.

    Minimal torch-only stand-in for a clustering library. Returns a LongTensor
    holding one cluster id per row. The input is detached, so no gradients flow.
    """
    data = embeddings.detach()
    num_clusters = min(num_clusters, data.shape[0])
    # Seed the centroids from randomly chosen distinct rows.
    centroids = data[torch.randperm(data.shape[0])[:num_clusters]].clone()
    for _ in range(iters):
        labels = torch.cdist(data, centroids).argmin(dim=1)
        for c in range(num_clusters):
            members = data[labels == c]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[c] = members.mean(dim=0)
    return torch.cdist(data, centroids).argmin(dim=1)


r = Recommend(12, 1000, 512)
x = torch.tensor([[0, 0], [0, 1], [1, 0]])
y = torch.tensor([1, 0, 1]).float()
# user 0 likes item 0
# user 0 dislikes item 1
# user 1 likes item 0

optimizer = torch.optim.Adam(r.parameters(), lr=1e-3)
for _ in range(100):
    optimizer.zero_grad()  # backward() accumulates, so clear the previous step
    # The model returns (batch, 1); squeeze to (batch,) so it matches y, as
    # binary_cross_entropy requires equal input/target sizes.
    output = r(x).squeeze(1)
    loss = F.binary_cross_entropy(output, y)
    loss.backward()
    optimizer.step()

# predict if user 1 likes item 1
with torch.no_grad():
    prediction = r(torch.tensor([[1, 1]]))

# cluster user and item embeddings to find similar users and items
user_clusters = cluster(r.user_embed.weight)
item_clusters = cluster(r.item_embed.weight)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment