Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Hierarchical Softmax CNN Classification
import torch
import torch.nn as nn
import torch.nn.init as init
# Dropout probability applied to the concatenated, pooled CNN features.
dropout_prob = 0.5


class FlatCnnLayer(nn.Module):
    """Parallel convolutions of several widths over an embedded token sequence.

    Each filter width gets its own branch: a Conv2d whose kernel spans the
    full embedding dimension, a ReLU, and a max-pool over every valid
    convolution position. The branch outputs are concatenated, flattened to
    one vector per sample and passed through dropout.

    Args:
        embedding_size: size of each token embedding (last input dimension).
        sequence_length: number of tokens per sample. The max-pool window is
            derived from it, so inputs are expected to have exactly this length.
        filter_sizes: convolution heights (tokens per window), one branch each.
        out_channels: number of feature maps produced by each branch.

    Input is expected as (batch, 1, sequence_length, embedding_size); the
    output then has shape (batch, out_channels * len(filter_sizes)).
    """

    def __init__(self, embedding_size, sequence_length, filter_sizes=(3, 4, 5), out_channels=128):
        super(FlatCnnLayer, self).__init__()
        self.embedding_size = embedding_size
        self.sequence_length = sequence_length
        self.out_channels = out_channels
        self.filter_layers = nn.ModuleList(
            [self._make_filter_layer(filter_size) for filter_size in filter_sizes])
        self.dropout = nn.Dropout(p=dropout_prob)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # The in-place initialisers replace the deprecated init.normal/init.constant.
                init.normal_(m.weight, mean=0, std=0.1)
                init.constant_(m.bias, 0.1)

    def forward(self, x):
        """Return the dropout-regularised, concatenated max-pooled features of *x*."""
        pools = [filter_layer(x) for filter_layer in self.filter_layers]
        x = torch.cat(pools, dim=1)  # stack branches along the channel axis
        x = x.view(x.size(0), -1)
        return self.dropout(x)

    def _make_filter_layer(self, filter_size):
        """Build one conv -> ReLU -> max-pool branch for windows of *filter_size* tokens."""
        return nn.Sequential(
            nn.Conv2d(1, self.out_channels, (filter_size, self.embedding_size)),
            nn.ReLU(inplace=True),
            # The window covers all sequence_length - filter_size + 1 conv
            # positions, leaving a single value per channel.
            nn.MaxPool2d((self.sequence_length - filter_size + 1, 1), stride=1),
        )
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from FlatCnnLayer import FlatCnnLayer
from TreeTools import TreeTools
import multiprocessing
import numpy as np
# Training hyper-parameters used by fit(): samples per batch, number of
# passes over the training set, and how often (in epochs) to print the loss.
batch_size, n_epochs, display_step = 128, 200, 5
# DataLoader worker processes: all CPU cores but one, never fewer than one.
N_WORKERS = max(multiprocessing.cpu_count() - 1, 1)
class HierarchicalTextClassifyCnnNet(nn.Module):
    """CNN text classifier with one linear softmax head per internal tree node.

    A shared FlatCnnLayer turns each input into a feature vector. Every
    internal node of *tree*, in the order yielded by TreeTools.get_subtrees,
    owns an nn.Linear that scores that vector over the node's children. The
    label chooses which node heads are evaluated (only those on its path),
    and a sample's loss is the summed cross-entropy of the branching
    decisions along that path.

    Args:
        embedding_size, sequence_length, filter_sizes, out_channels:
            forwarded to FlatCnnLayer.
        tree: nested (value, subtrees) label hierarchy. Node values are used
            as dict keys, so a duplicated value overwrites the earlier entry.
    """

    def __init__(self, embedding_size, sequence_length, tree, filter_sizes=(3, 4, 5), out_channels=128):
        super(HierarchicalTextClassifyCnnNet, self).__init__()
        self._tree_tools = TreeTools()
        self.tree = tree
        feature_size = out_channels * len(filter_sizes)
        # One weight matrix and bias vector per internal node; self.fc[i]
        # belongs to the i-th node yielded by get_subtrees.
        self.fc = nn.ModuleList([nn.Linear(feature_size, len(subtree[1]))
                                 for subtree in self._tree_tools.get_subtrees(tree)])
        # label value -> (child index chosen at each level, fc index at each level)
        self.value_to_path_and_nodes_dict = {}
        for path, value in self._tree_tools.get_paths(tree):
            nodes = self._tree_tools.get_nodes(tree, path)
            self.value_to_path_and_nodes_dict[value] = path, nodes
        self.flat_layer = FlatCnnLayer(embedding_size, sequence_length, filter_sizes=filter_sizes,
                                       out_channels=out_channels)
        # Wraps the same FlatCnnLayer instance, so its weights appear under both
        # "flat_layer." and "features.0." in state_dict(); kept for checkpoint
        # compatibility.
        self.features = nn.Sequential(self.flat_layer)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # In-place initialisers replace the deprecated xavier_uniform/constant.
                init.xavier_uniform_(m.weight, gain=np.sqrt(2.0))
                init.constant_(m.bias, 0.1)

    def forward(self, inputs, targets):
        """Return (losses, predicts) for a batch.

        losses[i] is a one-element tensor holding the summed path loss of
        sample i; predicts[i] is the list of node score vectors along sample
        i's path. Both are lists, so they can be iterated more than once.
        """
        features = self.features(inputs)
        predicts = [self._get_predicts(feature, target)
                    for feature, target in zip(features, targets)]
        losses = [self._get_loss(predict, target)
                  for predict, target in zip(predicts, targets)]
        return losses, predicts

    @staticmethod
    def _label_key(label):
        """Return the integer label value stored in a 0-d or one-element tensor."""
        return int(label.reshape(-1)[0])

    def _get_loss(self, predicts, label):
        """Sum the cross-entropy of every branching decision on *label*'s path.

        Targets are created on the device of the scores, so this works whether
        or not the model was moved to the GPU. The result is reshaped to one
        element so per-sample losses can be joined with torch.cat.
        """
        path, _ = self.value_to_path_and_nodes_dict[self._label_key(label)]
        criterion = nn.CrossEntropyLoss()
        step_losses = [
            criterion(predict.unsqueeze(0),
                      torch.tensor([p], dtype=torch.long, device=predict.device))
            for predict, p in zip(predicts, path)
        ]
        return torch.stack(step_losses).sum().reshape(1)

    def _get_predicts(self, feature, label):
        """Score *feature* with the head of each internal node on *label*'s path."""
        _, nodes = self.value_to_path_and_nodes_dict[self._label_key(label)]
        return [self.fc[n](feature) for n in nodes]
def fit(model, data, save_path):
    """Train *model*, report its test accuracy and save its feature extractor.

    Trains with Adam (lr=0.001) for the module-level n_epochs, using batches
    of batch_size. The mean per-sample training loss is printed at epoch 1
    and every display_step epochs. The model is then evaluated on the test
    split, and only model.flat_layer's state_dict is written to *save_path*.

    Args:
        model: module whose forward(inputs, labels) returns (losses, predicts)
            and which exposes flat_layer and value_to_path_and_nodes_dict.
        data: mapping with numpy arrays under 'X_train', 'Y_train',
            'X_test' and 'Y_test'.
        save_path: destination passed to torch.save.
    """
    # Choose the device once. The old checks tested `torch.cuda.is_available`
    # without calling it, which is always truthy and crashed on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    x_train, x_test = torch.from_numpy(data['X_train']).float(), torch.from_numpy(data['X_test']).float()
    y_train, y_test = torch.from_numpy(data['Y_train']).int(), torch.from_numpy(data['Y_test']).int()
    train_set = TensorDataset(x_train, y_train)
    test_set = TensorDataset(x_test, y_test)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=N_WORKERS,
                              pin_memory=torch.cuda.is_available())
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=N_WORKERS)
    n_train = max(len(train_set), 1)  # guard the loss report against an empty train split

    model.train()
    for epoch in range(1, n_epochs + 1):  # loop over the dataset multiple times
        acc_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            losses, _ = model(inputs, labels)
            # reshape(1) accepts both 0-d and one-element per-sample losses.
            loss = torch.mean(torch.cat([sample_loss.reshape(1) for sample_loss in losses], dim=0))
            # Weight the batch mean by batch size so the epoch figure is a
            # per-sample average once divided by the dataset size.
            acc_loss += loss.item() * inputs.size(0)
            loss.backward()
            optimizer.step()
        if epoch % display_step == 0 or epoch == 1:
            print('[%3d] loss: %.5f' %
                  (epoch, acc_loss / n_train))
    print('\rFinished Training\n')

    nb_test_corrects, nb_test_samples = _evaluate(model, test_loader, device)
    accuracy = 100.0 * nb_test_corrects / nb_test_samples if nb_test_samples else 0.0
    print('Accuracy of the network {:.2f}% ({:d} / {:d})'.format(
        accuracy,
        nb_test_corrects,
        nb_test_samples)
    )
    torch.save(model.flat_layer.state_dict(), save_path)


def _evaluate(model, loader, device):
    """Return (correct, total) sample counts for *model* over *loader*, without gradients."""
    model.eval()
    corrects, samples = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, predicts = model(inputs, labels)
            samples += labels.size(0)
            for predicted, label in zip(predicts, labels):
                corrects += _check_predicts(model, predicted, label)
    return corrects, samples
def _check_predicts(model, predicts, label):
path, _ = model.value_to_path_and_nodes_dict[int(label.data[0])]
for predict, p in zip(predicts, path):
if np.argmax(predict.data) != p:
return 0
return 1
# Trees are nested (value, subtrees) pairs: a node whose second element is a
# list is internal; any other second element marks a leaf.
class TreeTools:
    """Traversal helpers for (value, subtrees) trees.

    Internal nodes are numbered in pre-order (root = 0), consistently between
    get_subtrees and get_nodes, so the indices can select per-node modules.
    """

    def __init__(self):
        # Memo for _count_nodes: id(children list) -> (children list, size).
        # The list itself is stored so it stays alive and its id cannot be
        # reused by another object while the entry exists.
        self._count_nodes_dict = {}

    # Return whether the tree is an internal node (i.e. not a leaf)
    @staticmethod
    def _is_not_leave(tree):
        return isinstance(tree[1], list)

    def get_subtrees(self, tree):
        """Yield *tree* and every internal descendant in pre-order; leaves are skipped."""
        yield tree
        if self._is_not_leave(tree):
            for subtree in tree[1]:
                if self._is_not_leave(subtree):
                    for x in self.get_subtrees(subtree):
                        yield x

    def get_paths(self, tree):
        """Yield (path, value) for every node below *tree*, in pre-order.

        *path* is the list of child indices leading from *tree* to the node.
        Both internal nodes and leaves are yielded; a leaf *tree* yields nothing.
        """
        if not self._is_not_leave(tree):
            return
        for i, subtree in enumerate(tree[1]):
            yield [i], subtree[0]
            for path, value in self.get_paths(subtree):
                yield [i] + path, value

    def count_nodes(self, tree):
        """Return the number of internal nodes in *tree*, not counting the root."""
        return self._count_children(tree[1])

    def _count_nodes(self, branches):
        """Return the number of internal nodes within the sibling list *branches*.

        *branches* may be a temporary (e.g. a slice built in get_nodes), so it
        is never memoized itself: a freed temporary's id can be reused by the
        next slice and would otherwise return a stale count.
        """
        return sum(1 + self._count_children(node[1])
                   for node in branches if self._is_not_leave(node))

    def _count_children(self, children):
        """Memoized _count_nodes for a child list that belongs to the tree."""
        cached = self._count_nodes_dict.get(id(children))
        if cached is not None and cached[0] is children:
            return cached[1]
        size = self._count_nodes(children)
        self._count_nodes_dict[id(children)] = (children, size)
        return size

    def get_nodes(self, tree, path):
        """Return the pre-order indices of the internal nodes visited along *path*.

        One index is produced per decision (stopping early if a leaf is
        reached); index 0 is *tree* itself.
        """
        next_node = 0
        nodes = []
        for decision in path:
            nodes.append(next_node)
            if not self._is_not_leave(tree):
                break
            # Skip past the current node and every internal node inside the
            # siblings that precede the chosen child.
            next_node += 1 + self._count_nodes(tree[1][:decision])
            tree = tree[1][decision]
        return nodes
@mustafa-qamaruddin
Copy link

mustafa-qamaruddin commented Feb 8, 2018

It is not quite clear to me why you need a separate layer for each internal node in the target-word partitioning tree:

https://gist.github.com/paduvi/588bc95c13e73c1e5110d4308e6291ab#file-hierarchicaltextclassifycnnnet-py-L25

# create a weight matrix and bias vector for each node in the tree
        self.fc = nn.ModuleList([nn.Linear(out_channels * len(filter_sizes), len(subtree[1])) for subtree in
                                 self._tree_tools.get_subtrees(tree)])

I understood the internal-node representation v′ to mean that each internal node has only one neuron in the neural-network layer. The architecture should then be:
Hidden Layer > Internal Nodes Layer > Target Words Layer

How does that sound?

@keshav47
Copy link

keshav47 commented Jan 31, 2020

Could you please provide a sample data file, or describe your input data format, to make the tree structure and the code easier to understand?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment