Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Embeddings Network Class
class EmbeddingNet(nn.Module):
    """
    Creates a dense network with embedding layers.

    User and movie indices are looked up in two separate embedding tables.
    The two embedding vectors are concatenated, passed through dropout and a
    stack of ``Linear -> ReLU [-> Dropout]`` layers, and finally projected to
    a single output squashed by a sigmoid.

    Args:
        n_users:
            Number of unique users in the dataset.
        n_movies:
            Number of unique movies in the dataset.
        n_factors:
            Number of columns in the embeddings matrix.
        embedding_dropout:
            Dropout rate to apply right after embeddings layer.
        hidden:
            A single integer or a list of integers defining the number of
            units in hidden layer(s). An empty list means no hidden layers:
            the output layer is then applied directly to the concatenated
            embeddings.
        dropouts:
            A single number or a list of numbers defining the dropout rates
            applied right after each of the hidden layers. It may be shorter
            than ``hidden``; hidden layers without a matching rate, or with a
            rate of 0, get no dropout layer.

    Raises:
        ValueError: If more dropout rates than hidden layers are given.
    """

    def __init__(self, n_users, n_movies,
                 n_factors=50, embedding_dropout=0.02,
                 hidden=10, dropouts=0.2):
        super().__init__()
        hidden = get_list(hidden)
        dropouts = get_list(dropouts)
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if len(dropouts) > len(hidden):
            raise ValueError(
                f'got {len(dropouts)} dropout rates for {len(hidden)} '
                'hidden layers; there must not be more rates than layers')

        # Modules are created in the same order as before (embeddings first,
        # then hidden layers, then fc) so RNG consumption and parameter
        # registration order are unchanged.
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.drop = nn.Dropout(embedding_dropout)

        n_in = n_factors * 2  # user and movie embeddings are concatenated
        layers = []
        # zip_longest pads the (possibly shorter) dropouts list with None.
        for n_out, rate in zip_longest(hidden, dropouts):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
            if rate is not None and rate > 0.:
                layers.append(nn.Dropout(rate))
            n_in = n_out
        self.hidden = nn.Sequential(*layers)

        # n_in is now the width of the last hidden layer, or of the
        # concatenated embeddings when there are no hidden layers.
        self.fc = nn.Linear(n_in, 1)
        self._init()

    def forward(self, users, movies, minmax=None):
        """Predict a score for each (user, movie) pair.

        Args:
            users: Integer tensor of user indices (presumably shape (batch,)
                -- TODO confirm against callers).
            movies: Integer tensor of movie indices with the same shape as
                ``users``.
            minmax: Optional ``(min_rating, max_rating)`` pair. When given,
                the sigmoid output in (0, 1) is rescaled to
                (min_rating - 0.5, max_rating + 0.5), so that rounding a
                prediction can reach both ends of the rating scale.

        Returns:
            Tensor of predictions with a trailing dimension of size 1.
        """
        # Concatenate along the last (feature) axis: identical to dim=1 for
        # 1-D index inputs, and also correct when extra leading dims exist.
        features = torch.cat([self.u(users), self.m(movies)], dim=-1)
        x = self.drop(features)
        x = self.hidden(x)
        out = torch.sigmoid(self.fc(x))
        if minmax is not None:
            min_rating, max_rating = minmax
            out = out * (max_rating - min_rating + 1) + min_rating - 0.5
        return out

    def _init(self):
        """Re-initialise all parameters.

        Embedding weights are drawn uniformly from [-0.05, 0.05]; every
        ``nn.Linear`` (hidden layers and ``fc``) gets Xavier-uniform weights
        and a constant bias of 0.01.
        """
        def init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.01)

        nn.init.uniform_(self.u.weight, -0.05, 0.05)
        nn.init.uniform_(self.m.weight, -0.05, 0.05)
        self.hidden.apply(init)
        init(self.fc)
def get_list(n):
    """Normalise a layer configuration value to a list.

    Args:
        n: A single real number (e.g. ``10`` or ``0.2``) or a non-string
            iterable of numbers.

    Returns:
        ``[n]`` for a single number, otherwise ``list(n)``.

    Raises:
        TypeError: If ``n`` is neither a real number nor a non-string
            iterable.
    """
    import numbers

    # numbers.Real covers int, float, bool, Fraction and NumPy scalars such
    # as np.int64, which are not subclasses of the built-in int.
    if isinstance(n, numbers.Real):
        return [n]
    # Strings are iterable, but list('10') == ['1', '0'] is never a valid
    # layer configuration, so reject them explicitly.
    if hasattr(n, '__iter__') and not isinstance(n, (str, bytes)):
        return list(n)
    raise TypeError('layers configuration should be a single number or a list of numbers')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment