Last active
June 1, 2019 06:43
-
-
Save singhrahuldps/a7a747e56d2811d385ef815e2d4f3c21 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Neural net based on embedding matrices: dot-product matrix factorisation
# with a learned bias per user and per item.
# Model reference -> https://github.com/fastai/fastai/
class EmbeddingModel(nn.Module):
    """Predict a user/item score from learned embeddings.

    The raw score for a (user, item) pair is the dot product of their
    ``n_factors``-dimensional embeddings plus a learned per-user bias and a
    learned per-item bias. That raw score is passed through a sigmoid and
    rescaled so the output lies between ``y_range[0]`` and ``y_range[1]``.
    Because of the sigmoid, the endpoints are only approached, never targeted
    exactly.

    Args:
        n_factors: Length of each user and item embedding vector.
        n_users: Number of rows in the user tables. User indices must lie in
            ``[0, n_users)``.
        n_items: Number of rows in the item tables. Item indices must lie in
            ``[0, n_items)``.
        y_range: ``(low, high)`` pair that bounds the output.
        initialise: Half-width of the uniform distribution
            ``U(-initialise, initialise)`` used to initialise every embedding
            and bias table.
    """

    def __init__(self, n_factors, n_users, n_items, y_range, initialise=0.01):
        super().__init__()
        self.y_range = y_range
        self.u_weight = nn.Embedding(n_users, n_factors)
        self.i_weight = nn.Embedding(n_items, n_factors)
        self.u_bias = nn.Embedding(n_users, 1)
        self.i_bias = nn.Embedding(n_items, 1)
        # Re-initialise the weights of all four tables uniformly. A small
        # half-width keeps the initial raw scores near 0, so early
        # predictions sit near the middle of y_range.
        # nn.init.uniform_ updates the parameter in place under no_grad,
        # which replaces the discouraged `.weight.data` access.
        for emb in (self.u_weight, self.i_weight, self.u_bias, self.i_bias):
            nn.init.uniform_(emb.weight, -initialise, initialise)

    def forward(self, users, items):
        """Return predicted scores for paired ``users`` and ``items``.

        Args:
            users: Integer tensor of user indices of any shape, e.g.
                ``(batch,)`` or ``(batch, k)``.
            items: Integer tensor of item indices, same shape as ``users``.

        Returns:
            Tensor with the same shape as the index tensors, holding scores
            scaled into ``y_range``.
        """
        # Multiply the user and item embeddings element-wise, then sum over
        # the factor (last) axis to get the dot product for each pair.
        # Summing over the last axis, rather than axis 1, supports any
        # leading index shape, including a single 0-d index.
        dot = (self.u_weight(users) * self.i_weight(items)).sum(-1)
        # The bias tables have a trailing size-1 axis. Drop only that axis,
        # so the biases line up with `dot` whatever the batch shape.
        user_bias = self.u_bias(users).squeeze(-1)
        item_bias = self.i_bias(items).squeeze(-1)
        res = dot + user_bias + item_bias
        # Squash into (0, 1) with a sigmoid, then rescale to y_range.
        low, high = self.y_range[0], self.y_range[1]
        return torch.sigmoid(res) * (high - low) + low
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment