Skip to content

Instantly share code, notes, and snippets.

@GINK03
Created February 16, 2019 12:05
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save GINK03/9ac2151159c6cc0887e5be36cd341dc2 to your computer and use it in GitHub Desktop.
pytorch_matrix_factorize
class MF(nn.Module):
    """Matrix-factorization-style scorer for (item, user) index pairs.

    Items and users each get their own embedding table followed by a
    linear projection with ReLU.  The two projected vectors are combined
    by element-wise product and mapped to a single scalar by a final
    linear layer, which is also passed through ReLU.

    Args:
        input_items: Number of rows in the item embedding table.
        input_users: Number of rows in the user embedding table.
        embedding_dim: Width of both embedding tables (default 768).
        hidden_dim: Width of the projected item/user vectors that are
            multiplied together (default 512).
    """

    def __init__(self, input_items: int, input_users: int, *,
                 embedding_dim: int = 768, hidden_dim: int = 512):
        super().__init__()
        # Kept from the original implementation: reports the item
        # vocabulary size on construction.
        print('item size', input_items)

        # Item tower: embedding lookup -> linear projection.
        self.l_b1 = nn.Embedding(
            num_embeddings=input_items, embedding_dim=embedding_dim)
        self.l_b2 = nn.Linear(
            in_features=embedding_dim, out_features=hidden_dim, bias=True)

        # User tower: same layout as the item tower, separate weights.
        self.l_a1 = nn.Embedding(
            num_embeddings=input_users, embedding_dim=embedding_dim)
        self.l_a2 = nn.Linear(
            in_features=embedding_dim, out_features=hidden_dim, bias=True)

        # Scoring head: reduces the combined vector to one value.
        self.l_l1 = nn.Linear(
            in_features=hidden_dim, out_features=1, bias=True)

    def encode_item(self, x):
        """Return the ReLU-activated projection of the item embeddings.

        Args:
            x: Item indices to look up in the item embedding table
                (assumed to be an integer tensor with values below
                ``input_items`` -- confirm against the caller).

        Returns:
            Tensor whose last dimension is ``hidden_dim``.
        """
        x = self.l_b1(x)
        x = F.relu(self.l_b2(x))
        return x

    def encode_user(self, x):
        """Return the ReLU-activated projection of the user embeddings.

        Args:
            x: User indices to look up in the user embedding table
                (assumed to be an integer tensor with values below
                ``input_users`` -- confirm against the caller).

        Returns:
            Tensor whose last dimension is ``hidden_dim``.
        """
        x = self.l_a1(x)
        x = F.relu(self.l_a2(x))
        return x

    def forward(self, inputs):
        """Score each (item, user) pair.

        Args:
            inputs: A two-element sequence ``(item_indices, user_indices)``,
                in that order.

        Returns:
            Tensor whose last dimension is 1.  Because of the final ReLU
            every score is >= 0.
        """
        item_vec, user_vec = inputs
        item_vec = self.encode_item(item_vec)
        user_vec = self.encode_user(user_vec)
        # Element-wise product of the two encodings, then a learned
        # linear reduction to a scalar.
        return F.relu(self.l_l1(user_vec * item_vec))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment