Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created May 28, 2018 07:48
Show Gist options
  • Save snakers4/aa4f24e9f6fb3266db42f0b98dfb7d0c to your computer and use it in GitHub Desktop.
Save snakers4/aa4f24e9f6fb3266db42f0b98dfb7d0c to your computer and use it in GitHub Desktop.
Playing with MLP + embeddings in PyTorch
import torch
import torch.nn as nn
from torch.autograd import Variable
class NaiveClassifier(nn.Module):
    """MLP classifier over numerical features plus embedded categorical ones.

    Each categorical column gets its own ``nn.Embedding`` (registered as the
    attribute ``emb_<i>``).  The embedding outputs are concatenated with the
    numerical features and passed through a stack of
    Linear -> BatchNorm1d -> ReLU -> Dropout blocks stored in
    ``self.classifier``.

    Args:
        cat_sizes: iterable with the number of distinct values of every
            categorical column, in column order.  ``None`` (the default)
            means "no categorical columns".
        numerical_features: number of numerical input columns.
        mlp_sizes: output width of every linear block, in order; the last
            entry is the width of the model output.
        embedding_factor: divisor used to derive an embedding width from a
            category count (see ``create_emb``).
        dropout_rate: dropout probability used in every block.
    """

    def __init__(self,
                 cat_sizes=None,
                 numerical_features=117,
                 mlp_sizes=(1024, 2048, 1024, 512, 256, 128, 2),
                 embedding_factor=3,
                 dropout_rate=0.1):
        super(NaiveClassifier, self).__init__()

        # Copy the incoming sequences: avoids a shared mutable default and
        # avoids aliasing a list that the caller may later modify.
        self.cat_sizes = [] if cat_sizes is None else list(cat_sizes)
        self.mlp_sizes = list(mlp_sizes)
        self.dropout_rate = dropout_rate
        self.embedding_factor = embedding_factor
        # Not used by this class; kept so the attribute still exists for any
        # external code that reads it.
        self.embeddings = {}

        embedding_sizes = []
        for i, cat_size in enumerate(self.cat_sizes):
            emb, emb_size = self.create_emb(cat_size=cat_size, max_emb_size=50)
            # Attribute assignment registers the embedding as a submodule,
            # so its weights appear in parameters() / state_dict().
            setattr(self, 'emb_{}'.format(i), emb)
            embedding_sizes.append(emb_size)

        # Widths of the whole stack: concatenated input first, then each block.
        layer_sizes = [sum(embedding_sizes) + numerical_features] + self.mlp_sizes
        # Historical quirk kept for backward compatibility: after construction
        # this attribute holds the list of layer widths, not the raw count.
        self.numerical_features = layer_sizes

        # NOTE(review): every block, including the last one, ends with
        # BatchNorm1d + ReLU + Dropout, so the model output is non-negative
        # and subject to dropout in training mode -- confirm this is intended.
        modules = []
        for neurons_in, neurons_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            modules = self.linear_block(modules, neurons_in, neurons_out)
        self.classifier = nn.Sequential(*modules)

    def create_emb(self,
                   cat_size=7,
                   max_emb_size=50):
        """Create an embedding layer for one categorical column.

        The width is ``(cat_size + 2) // embedding_factor`` capped at
        ``max_emb_size`` and never smaller than 1, so a column with few
        categories is not silently reduced to a zero-width embedding.

        Returns:
            Tuple ``(embedding_module, embedding_width)``.
        """
        emb_size = min((cat_size + 2) // self.embedding_factor, max_emb_size)
        emb_size = max(1, emb_size)
        emb = nn.Embedding(num_embeddings=cat_size,
                           embedding_dim=emb_size)
        return emb, emb_size

    def forward(self,
                x_num,
                x_cat):
        """Run the classifier.

        Args:
            x_num: 2-D tensor of numerical features (rows are samples).
            x_cat: 2-D integer tensor of category indices; column ``i`` is
                looked up in ``emb_<i>``.  Ignored when the model was built
                without categorical columns.

        Returns:
            Output of ``self.classifier`` for the concatenated features.
        """
        if not self.cat_sizes:
            # Nothing to embed; the first linear layer expects x_num only.
            return self.classifier(x_num)

        embedded = [
            getattr(self, 'emb_{}'.format(i))(x_cat[:, i])
            for i in range(len(self.cat_sizes))
        ]
        # Embeddings first, numerical features last, joined along features.
        out = torch.cat(embedded + [x_num], dim=1)
        return self.classifier(out)

    def linear_block(self,
                     modules,
                     neurons_in,
                     neurons_out):
        """Append one Linear -> BatchNorm1d -> ReLU -> Dropout block.

        ``modules`` is extended in place and also returned.
        """
        modules.append(nn.Linear(neurons_in, neurons_out))
        modules.append(nn.BatchNorm1d(neurons_out))
        modules.append(nn.ReLU(True))
        modules.append(nn.Dropout(self.dropout_rate))
        return modules
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment