Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created April 23, 2020 12:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save khanhnamle1994/7db2f0f77772eebd07b24be98b53a1b8 to your computer and use it in GitHub Desktop.
Save khanhnamle1994/7db2f0f77772eebd07b24be98b53a1b8 to your computer and use it in GitHub Desktop.
Neural Collaborative Filtering model
import torch
from layer import FeaturesEmbedding, MultiLayerPerceptron
class NeuralCollaborativeFiltering(torch.nn.Module):
    """
    A Pytorch implementation of Neural Collaborative Filtering.

    Two branches are computed from the same field embeddings and fused by a
    final linear layer followed by a sigmoid:

    * an MLP over the flattened embeddings of all fields, and
    * the element-wise product of the user and item embeddings (GMF).

    Reference:
        X He, et al. Neural Collaborative Filtering, 2017.
    """

    def __init__(self, field_dims, user_field_idx, item_field_idx, embed_dim, mlp_dims, dropout):
        """
        :param field_dims: per-field sizes forwarded to ``FeaturesEmbedding``;
            ``len(field_dims)`` is taken as the number of input fields
        :param user_field_idx: index (int or length-1 index sequence) of the
            user field along dimension 1 of the embedded input
        :param item_field_idx: index (int or length-1 index sequence) of the
            item field along dimension 1 of the embedded input
        :param embed_dim: embedding width of a single field
        :param mlp_dims: layer widths forwarded to ``MultiLayerPerceptron``;
            the last entry is used as the MLP output width
        :param dropout: dropout setting forwarded to ``MultiLayerPerceptron``
        """
        super().__init__()
        self.user_field_idx = user_field_idx
        self.item_field_idx = item_field_idx
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False)
        # The fusion layer sees the GMF vector (embed_dim) next to the MLP output.
        self.fc = torch.nn.Linear(mlp_dims[-1] + embed_dim, 1)

    @staticmethod
    def _select_field(embedded, field_idx):
        """Return the embedding of one field as a 2-D tensor."""
        selected = embedded[:, field_idx]
        # A sequence-style index keeps a singleton field axis that has to go.
        # A scalar index already yields a 2-D slice; squeezing that one would
        # drop the embedding axis whenever embed_dim == 1 and break the
        # concatenation in ``forward``.
        if selected.dim() == 3:
            selected = selected.squeeze(1)
        return selected

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)`` with one
            column per entry of ``field_dims``
        :return: tensor of size ``(batch_size,)`` with values in ``(0, 1)``
        """
        # Presumably (batch_size, num_fields, embed_dim) -- the indexing and
        # the ``view`` below rely on that layout of the embedding output.
        x = self.embedding(x)
        user_x = self._select_field(x, self.user_field_idx)
        item_x = self._select_field(x, self.item_field_idx)
        x = self.mlp(x.view(-1, self.embed_output_dim))
        gmf = user_x * item_x
        x = torch.cat([gmf, x], dim=1)
        x = self.fc(x).squeeze(1)
        return torch.sigmoid(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment