Skip to content

Instantly share code, notes, and snippets.

# kendricktan/capsule_networks.py Last active Dec 1, 2019

Clean Code for Capsule Networks
 # capsule_networks.py
"""Dynamic Routing Between Capsules.

https://arxiv.org/abs/1710.09829
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable


def index_to_one_hot(index_tensor, num_classes=10):
    """Convert a 1-D tensor of class indices to one-hot rows.

    e.g. [2, 5] (with 10 classes) becomes:
    [
      [0 0 1 0 0 0 0 0 0 0]
      [0 0 0 0 0 1 0 0 0 0]
    ]

    The result is created on the same device as ``index_tensor``.
    """
    index_tensor = index_tensor.long()
    eye = torch.eye(num_classes, device=index_tensor.device)
    return eye.index_select(dim=0, index=index_tensor)


def squash_vector(tensor, dim=-1, eps=1e-8):
    """Non-linear 'squashing' of the vectors along ``dim``.

    Short vectors get shrunk to almost zero length and long vectors get
    shrunk to a length slightly below 1; the direction is kept:
    ``v = |s|^2 / (1 + |s|^2) * s / |s|``.

    Args:
        tensor: input whose ``dim`` axis holds the vectors to squash.
        dim: axis along which the vector norm is computed.
        eps: added under the square root so that a zero vector maps to a
            zero vector (instead of NaN from 0/0) and the gradient of the
            norm stays finite.
    """
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * tensor / torch.sqrt(squared_norm + eps)


def softmax(input, dim=1):
    """Apply softmax along dimension ``dim`` of ``input``.

    Kept for backward compatibility. It originally worked around PyTorch
    releases whose ``F.softmax`` had no ``dim`` argument
    (https://github.com/pytorch/pytorch/issues/3235). Current PyTorch
    supports ``dim`` directly, so this delegates to it.
    """
    return F.softmax(input, dim=dim)


class CapsuleLayer(nn.Module):
    """A layer of capsules.

    The mode is selected by ``num_routes``:

    * ``num_routes == -1``: convolutional (primary) capsules. The layer holds
      ``num_capsules`` independent ``Conv2d`` units. Each conv output is
      flattened per sample, and the ``num_capsules`` outputs are stacked on
      the last axis to form the capsule vectors, which are then squashed.
    * otherwise: capsules fully connected to ``num_routes`` input capsules
      of size ``in_channels``. Each output capsule has size ``out_channels``
      and is computed with ``num_iterations`` rounds of dynamic routing.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels,
                 kernel_size=None, stride=None, num_iterations=3):
        super().__init__()
        self.num_routes = num_routes
        self.num_iterations = num_iterations
        self.num_capsules = num_capsules

        if num_routes != -1:
            # One (in_channels x out_channels) transform per
            # (output capsule, input route) pair.
            self.route_weights = nn.Parameter(
                torch.randn(num_capsules, num_routes, in_channels, out_channels)
            )
        else:
            self.capsules = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=0)
                for _ in range(num_capsules)
            ])

    def forward(self, x):
        """Compute this layer's capsule outputs.

        Routing mode returns a tensor of shape
        ``(num_capsules, batch, 1, 1, out_channels)``. Convolutional mode
        returns ``(batch, -1, num_capsules)``.
        """
        if self.num_routes != -1:
            # Prediction vectors u_hat. The broadcasted matmul yields
            # (num_capsules, batch, num_routes, 1, out_channels).
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
            # Routing logits b_ij, one per (output capsule, sample, route).
            # zeros_like keeps them on the inputs' device/dtype (no .cuda()).
            logits = torch.zeros_like(priors[..., :1])

            for i in range(self.num_iterations):
                # Eq. 3: the coupling coefficients of each input capsule sum
                # to 1 over the output capsules, which live on dim 0.
                probs = softmax(logits, dim=0)
                # Eq. 2 then Eq. 1: weighted sum over routes, THEN squash.
                outputs = squash_vector((probs * priors).sum(dim=2, keepdim=True))
                if i != self.num_iterations - 1:
                    # Agreement (dot product u_hat . v) raises the logit.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = squash_vector(outputs)

        return outputs


class MarginLoss(nn.Module):
    """Margin loss (Eq. 4 of the paper) plus a reconstruction regularizer.

    The defaults are the constants this file has always used: m+ = 0.9,
    m- = 0.1, lambda = 0.5, and a reconstruction weight of 0.0005. The total
    is averaged over the batch.
    """

    def __init__(self, m_plus=0.9, m_minus=0.1, lambda_=0.5,
                 reconstruction_weight=0.0005):
        super().__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_ = lambda_
        self.reconstruction_weight = reconstruction_weight
        # Reconstruction as regularization: summed squared error
        # (reduction='sum' replaces the deprecated size_average=False).
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        """Return the batch-averaged loss.

        Args:
            images: input images; flattened per sample to match
                ``reconstructions``.
            labels: one-hot targets, same shape as ``classes``.
            classes: capsule lengths per class.
            reconstructions: decoder output, one flat vector per sample.
        """
        left = F.relu(self.m_plus - classes) ** 2
        right = F.relu(classes - self.m_minus) ** 2
        margin_loss = labels * left + self.lambda_ * (1. - labels) * right
        margin_loss = margin_loss.sum()

        # The decoder emits flat vectors; flatten the images so shapes match.
        flat_images = images.reshape(images.size(0), -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, flat_images)

        return (margin_loss + self.reconstruction_weight * reconstruction_loss) / images.size(0)


class CapsuleNet(nn.Module):
    """CapsNet for single-channel 28x28 inputs with 10 classes.

    forward() returns ``(classes, reconstructions)``. ``classes`` holds the
    length of each digit capsule, shape (batch, 10). ``reconstructions`` is
    the decoder output, shape (batch, 784).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(
            8, -1, 256, 32, kernel_size=9, stride=2)
        # 10 is the number of classes
        self.digit_capsules = CapsuleLayer(10, 32 * 6 * 6, 8, 16)

        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from the masked digit capsules.

        If ``y`` (one-hot, shape (batch, 10)) is given, it selects the
        capsule used for reconstruction. Otherwise the longest capsule is
        used.
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        # (10, batch, 1, 1, 16) -> (batch, 10, 16). Squeeze explicit axes so
        # a batch of size 1 keeps its batch dimension.
        x = self.digit_capsules(x).squeeze(3).squeeze(2).transpose(0, 1)

        # Capsule lengths act directly as class scores; no softmax.
        classes = (x ** 2).sum(dim=-1) ** 0.5

        if y is None:
            # In all batches, get the most active capsule
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(10, dtype=x.dtype, device=x.device).index_select(
                dim=0, index=max_length_indices)

        reconstructions = self.decoder((x * y[:, :, None]).reshape(x.size(0), -1))

        return classes, reconstructions


def main(epochs=10, batch_size=8, use_cuda=None):
    """Train on MNIST and report test loss/accuracy after every epoch.

    ``use_cuda=None`` uses a GPU only when one is available.
    """
    # Script-only dependencies are imported here so that the model code
    # above can be imported without torchvision/tqdm installed.
    import torchvision.transforms as transforms
    from torchvision.datasets.mnist import MNIST
    from tqdm import tqdm

    if use_cuda is None:
        use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Model
    model = CapsuleNet().to(device)
    optimizer = optim.Adam(model.parameters())
    margin_loss = MarginLoss()

    train_loader = torch.utils.data.DataLoader(
        MNIST(root='/tmp', download=True, train=True, transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        MNIST(root='/tmp', download=True, train=False, transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        # Training
        train_loss = 0.
        model.train()
        for img, target in tqdm(train_loader, desc='Training'):
            img = img.to(device)
            target = index_to_one_hot(target).to(device)

            optimizer.zero_grad()
            classes, reconstructions = model(img, target)
            loss = margin_loss(img, target, classes, reconstructions)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # The loss is already a per-sample mean; average it over batches.
        train_loss /= len(train_loader)
        print('Training:, Avg Loss: {:.4f}'.format(train_loss))

        # Testing
        correct = 0
        test_loss = 0.
        model.eval()
        with torch.no_grad():
            for img, target in tqdm(test_loader, desc='test set'):
                img = img.to(device)
                target_index = target.to(device)
                target = index_to_one_hot(target).to(device)

                classes, reconstructions = model(img, target)
                test_loss += margin_loss(img, target, classes, reconstructions).item()

                # Predicted class = longest capsule.
                pred = classes.argmax(dim=1)
                correct += pred.eq(target_index).sum().item()

        test_loss /= len(test_loader)
        accuracy = 100. * correct / len(test_loader.dataset)
        print('Test Set: Avg Loss: {:.4f}, Accuracy: {:.4f}'.format(
            test_loss, accuracy))


if __name__ == '__main__':
    main(epochs=10)

### Atcold commented Nov 16, 2017 • edited

 Traceback (most recent call last): File "capsule_networks.py", line 230, in correct += pred.eq(target.data.view_as(pred)).cpu().sum() File "/home/atcold/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 198, in view_as return self.view(tensor.size()) RuntimeError: invalid argument 2: size '[8 x 1]' is invalid for input of with 80 elements at /opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/TH/THStorage.c:41  Also, tqdm and print make a mess on screen.
Owner Author

### kendricktan commented Nov 16, 2017

 @Atcold Ooops, my bad. It appears that I've pasted in an outdated version. I've updated the gist now and removed redundancy of tqdm and print.

### Atcold commented Nov 17, 2017

 Very well, @kendricktan. Two more remarks. You can (1) reintroduce tqdm in the training cycle (as long as you don't print the loss on screen), (2) factor out the feed-forward pass and loss evaluation, which are shared by both training and testing procedures. Furthermore, I'd recommend zeroing the gradient after the forward pass, and just before the backward pass, to reduce confusion.
Owner Author

### kendricktan commented Nov 17, 2017

 @Atcold, done for your remark 1. As for remark 2, I personally think the optimizer's state should be made explicit (zeroed) before anything else happens. Thanks for the feedback 👍

### balassbals commented Nov 17, 2017 • edited

 I have a question. `logits` on line 88 has size 10 x 128 x 1152 x 1 x 16, but the softmax is taken with respect to dim 2. Shouldn't it be with respect to dim 0, since we have 10 classes? Can you clarify? (Assuming a batch size of 128.)

### Atcold commented Nov 17, 2017 • edited

 @balassbals, there are a total of (6 × 6 × 32) 8D capsules u, which provide their prediction vectors \hat u. Each capsule input s is the weighted average of the corresponding \hat u. The weighting coefficients c are given by the softmax over the logits b, which are as many as the number of capsules in the layer below, i.e. 6 × 6 × 32. Therefore, it is correct to run the softmax on the 3rd dimension (i.e. dimension number 2). Please let me know if it is not clear. @kendricktan, the optimiser is part of the back-propagation algorithm, which starts after the forward pass. This is why I would recommend not mixing the two things. I have students who confuse the two... One last thing: this code does not run when CUDA = False at line 166. Instead of cuda(), use type_as(other_tensor).

### balassbals commented Nov 17, 2017

 @Atcold, I understand what you say. But still I'm confused since the paper says that coupling coeffs between capsule i and all the capsules in the layer above sum to 1 and equation 3 in paper supports this statement. But from what you say my understanding is that the coeffs between all capsules in layer l and capsule j in layer l + 1 sum to one. Can you clarify?

### Atcold commented Nov 17, 2017

 @balassbals, you are correct. Today I gave a talk at NYU about this paper, and people pointed out that the softmax is done across the first dimension (i.e. dimension number 0). I missed this the first time I read the paper. My bad. So you are correct, there is a mistake in this implementation. @kendricktan, did you follow the conversation? If so, please fix it.

### balassbals commented Nov 18, 2017

 @Atcold, but when I do it across dim 0 (10 classes), I don't get the expected results. Another PyTorch implementation I saw uses F.softmax incorrectly. I actually implemented it myself first, but I'm not getting the expected results, so I'm looking for a working version in PyTorch.

### Atcold commented Nov 20, 2017

 Also, why is there a softmax() at line L152? This should simply be the capsule's norm! Correct?

### pqn commented May 12, 2018

 @balassbals I have not found any working PyTorch implementations that softmax across the 10 classes (only across the 1152 routes, which does not match the paper). Have you discovered anything since?

### afmsaif commented Jun 13, 2018

 Hello, I have some experience with CapsNet written in TensorFlow, but I have no idea about PyTorch. Can you help me? I want to input data of size (224, 224, 3), and the target will be binary (0 or 1). What modifications do I have to make for this kind of data? Thanks in advance.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.