Skip to content

Instantly share code, notes, and snippets.

@Yash-567
Created July 7, 2020 07:27
Show Gist options
  • Save Yash-567/04d104b1e2011801bef42c8794c513bb to your computer and use it in GitHub Desktop.
Save Yash-567/04d104b1e2011801bef42c8794c513bb to your computer and use it in GitHub Desktop.
# Model sizes. N and M are derived from the ratings DataFrame `df`, which is
# defined elsewhere (it must expose `userId` and `movieId` columns).
# NOTE(review): N and M are used as nn.Embedding table sizes, so the ids are
# presumably re-encoded to contiguous 0..N-1 / 0..M-1 upstream — confirm.
K = 10  # embedding width: latent features per user and per movie
N = len(set(df["userId"].values))  # count of distinct user ids
M = len(set(df["movieId"].values))  # count of distinct movie ids
from torch import nn
import torch
import torch.nn.functional as F
class Network(nn.Module):
    """Neural collaborative-filtering model over (user, movie) index pairs.

    Each user index and each movie index is looked up in its own embedding
    table. The two vectors are concatenated and passed through one ReLU
    hidden layer, then a linear layer producing a single value per pair
    (presumably the predicted rating — confirm against the training loop).

    Args:
        n_users: Rows in the user embedding table. Defaults to the
            module-level ``N`` so ``Network()`` keeps its original behaviour.
        n_movies: Rows in the movie embedding table. Defaults to ``M``.
        n_factors: Embedding width (latent features). Defaults to ``K``.
        hidden_units: Width of the hidden layer (default 400).

    Index tensors passed to the model must be integer tensors with values in
    ``[0, n_users)`` / ``[0, n_movies)``, as ``nn.Embedding`` requires.
    """

    def __init__(
        self, n_users=None, n_movies=None, n_factors=None, hidden_units=400
    ):
        super().__init__()
        # Resolve the module-level sizes lazily so the model can also be
        # built without them (other datasets, tests).
        if n_users is None:
            n_users = N
        if n_movies is None:
            n_movies = M
        if n_factors is None:
            n_factors = K
        self.u_embeddings = nn.Embedding(n_users, n_factors)
        self.m_embeddings = nn.Embedding(n_movies, n_factors)
        self.hidden = nn.Linear(2 * n_factors, hidden_units)
        self.output = nn.Linear(hidden_units, 1)

    def forward(self, u, m):
        """Score user/movie index pairs.

        Args:
            u: Integer tensor of user indices (any shape; flattened).
            m: Integer tensor of movie indices with as many elements as ``u``.

        Returns:
            Tensor of shape ``(batch, 1)``, one output per pair.
        """
        # Flatten to 1-D so each lookup yields (batch, n_factors) directly.
        # This replaces the former view(-1, 10), which silently assumed the
        # embedding width was 10. reshape (unlike view) also accepts
        # non-contiguous index tensors.
        u_embedding = self.u_embeddings(u.reshape(-1))
        m_embedding = self.m_embeddings(m.reshape(-1))
        x = torch.cat([u_embedding, m_embedding], dim=1)
        x = F.relu(self.hidden(x))
        return self.output(x)

    def get_movie_embeddings(self, m):
        """Return the learned embeddings for movie indices ``m``.

        Returns:
            Tensor of shape ``(batch, 1, n_factors)`` (the singleton middle
            dimension is kept for compatibility with existing callers).
        """
        return self.m_embeddings(m.reshape(-1, 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment