Skip to content

Instantly share code, notes, and snippets.

@rbrigden rbrigden/multi_softmax.py
Last active Jul 20, 2018

Embed
What would you like to do?
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class MLP(nn.Module):
    """Two-layer perceptron with one log-softmax head per categorical variable.

    The second linear layer produces ``sum(feature_categories)`` logits.
    ``forward`` cuts those logits into consecutive, non-overlapping slices
    (one slice per categorical variable) and normalises each slice
    independently with ``log_softmax``.
    """

    def __init__(self, input_size, feature_categories):
        """Build the network.

        Args:
            input_size: Number of input features per example.
            feature_categories: Sequence of ints; entry ``i`` is the number
                of classes of the i-th categorical variable.
        """
        super(MLP, self).__init__()
        self.feature_categories = feature_categories
        out_size = int(np.sum(feature_categories))
        # Two fully connected layers
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, out_size)

    def forward(self, x):
        """Return a list of per-variable log-probabilities.

        Args:
            x: Tensor of shape ``(batch, input_size)``.

        Returns:
            A list with one tensor per entry of ``feature_categories``; the
            i-th tensor has shape ``(batch, feature_categories[i])`` and holds
            log-probabilities (suitable for ``F.nll_loss``).
        """
        hidden = F.relu(self.fc1(x))
        # Project to the concatenated logits of all categorical variables.
        logits = self.fc2(hidden)
        start = 0
        outputs = []
        # Take a softmax over each categorical variable
        for csize in self.feature_categories:
            category_logits = F.log_softmax(logits[:, start:start + csize], dim=1)
            outputs.append(category_logits)
            # Advance so the next variable reads its own slice of the logits.
            start += csize
        return outputs
def random_labels(batch_size, feature_categories):
    """Create a random batch of index labels for each category.

    Args:
        batch_size: Number of labels to draw per categorical variable.
        feature_categories: Iterable of ints; each entry is the exclusive
            upper bound passed to ``torch.randint`` for that variable.

    Returns:
        A list with one ``torch.LongTensor`` of shape ``(batch_size,)`` per
        entry of ``feature_categories``.
    """
    labels = []
    for num_classes in feature_categories:
        drawn = torch.randint(num_classes, (batch_size,))
        # Cast explicitly so the labels are usable as nll_loss targets.
        labels.append(drawn.type(torch.LongTensor))
    return labels
if __name__ == "__main__":
    # Demo: one forward/backward pass on random data with several
    # categorical targets predicted jointly.
    input_size = 784
    batch_size = 128

    # Four categorical variables with 5, 4, 7 and 9 classes respectively.
    feature_categories = [5, 4, 7, 9]

    # Random inputs, then one random label vector per variable.
    data = torch.randn(batch_size, input_size)
    label_idxs = random_labels(batch_size, feature_categories)

    # Build the model and run the forward pass.
    model = MLP(input_size, feature_categories)
    out = model(data)

    # One negative log-likelihood term per variable, summed into a single
    # scalar so that a single backward pass covers every head.
    losses = []
    for head_log_probs, head_labels in zip(out, label_idxs):
        losses.append(F.nll_loss(head_log_probs, head_labels))
    net_loss = sum(losses)

    # Backpropagate through all heads at once.
    net_loss.backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.