Skip to content

Instantly share code, notes, and snippets.

Last active August 17, 2021 17:12
Show Gist options
  • Save kendricktan/9a776ec6322abaaf03cc9befd35508d4 to your computer and use it in GitHub Desktop.
Save kendricktan/9a776ec6322abaaf03cc9befd35508d4 to your computer and use it in GitHub Desktop.
Clean Code for Capsule Networks
Dynamic Routing Between Capsules
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
def index_to_one_hot(index_tensor, num_classes=10):
    """
    Convert a tensor of class indices to one-hot row vectors.

    e.g. [2, 5] (with 10 classes) becomes:
        [[0 0 1 0 0 0 0 0 0 0]
         [0 0 0 0 0 1 0 0 0 0]]

    Args:
        index_tensor: 1-D tensor of integer class indices (``index_select``
            requires a 1-D index).
        num_classes: length of each one-hot vector.

    Returns:
        Float tensor of shape (len(index_tensor), num_classes) on the same
        device as ``index_tensor``.
    """
    index_tensor = index_tensor.long()
    # Build the identity on the index's device so CUDA indices work too.
    identity = torch.eye(num_classes, device=index_tensor.device)
    return identity.index_select(dim=0, index=index_tensor)
def squash_vector(tensor, dim=-1, eps=1e-8):
    """
    Non-linear 'squashing' (Sabour et al., 2017, Eq. 1).

    Short vectors get shrunk to almost zero length and long vectors get
    shrunk to a length slightly below 1; the direction is preserved:

        v = ||s||^2 / (1 + ||s||^2) * s / ||s||

    Args:
        tensor: input vectors whose components lie along ``dim``.
        dim: dimension holding the vector components.
        eps: small constant added under the square root. Without it an
            all-zero vector yields 0 / 0 = NaN, and the gradient of sqrt at
            zero is infinite.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * tensor / torch.sqrt(squared_norm + eps)
def softmax(input, dim=1):
    """
    Apply softmax along an arbitrary dimension ``dim``.

    This helper originally emulated a dimension-aware softmax with a
    transpose/reshape, because ``F.softmax`` had no ``dim`` argument on
    PyTorch stable as of November 6th 2017. ``F.softmax`` now accepts ``dim``
    directly. That gives the same result without the extra copies and
    without the deprecated implicit-dimension behaviour.

    Args:
        input: tensor of any rank.
        dim: dimension along which values are normalised to sum to 1.

    Returns:
        Tensor of the same shape as ``input``.
    """
    return F.softmax(input, dim=dim)
class CapsuleLayer(nn.Module):
    """
    A layer of capsules (Sabour et al., 2017).

    Two modes, selected by ``num_routes``:

    * ``num_routes == -1``: primary capsules. ``num_capsules`` independent
      ``nn.Conv2d`` units; their flattened outputs are stacked along the
      last dimension and squashed.
    * otherwise: routed capsules. Every input capsule predicts every output
      capsule through ``route_weights``, and the predictions are combined by
      ``num_iterations`` rounds of routing-by-agreement.

    Args:
        num_capsules: number of capsules in this layer.
        num_routes: number of input capsules, or -1 for convolutional
            (primary) capsules.
        in_channels: input channels (primary) or input capsule dimension
            (routed).
        out_channels: output channels (primary) or output capsule dimension
            (routed).
        kernel_size, stride: forwarded to ``nn.Conv2d``; only used when
            ``num_routes == -1``.
        num_iterations: number of routing iterations; only used when routing.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels,
                 kernel_size=None, stride=None, num_iterations=3):
        super(CapsuleLayer, self).__init__()

        self.num_routes = num_routes
        self.num_iterations = num_iterations
        self.num_capsules = num_capsules

        if num_routes != -1:
            self.route_weights = nn.Parameter(
                torch.randn(num_capsules, num_routes,
                            in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, stride=stride, padding=0)
                 for _ in range(num_capsules)])

    def forward(self, x):
        # If routing is defined
        if self.num_routes != -1:
            # x is indexed with 3 axes: (batch, num_routes, in_channels).
            # (1, B, R, 1, in) @ (C, 1, R, in, out) -> (C, B, R, 1, out):
            # one prediction vector per (output capsule, sample, input capsule).
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Allocate on priors' device/dtype instead of hard-coding .cuda(),
            # which crashed CPU-only runs.
            logits = torch.zeros_like(priors)

            # Routing algorithm
            for i in range(self.num_iterations):
                # NOTE(review): Eq. 3 of the paper normalises the coupling
                # coefficients over the output capsules (dim 0 here). This
                # keeps the original choice of normalising over the input
                # routes (dim 2).
                probs = softmax(logits, dim=2)
                # Eq. 1-2: s_j = sum_i c_ij * u_hat_j|i, then v_j = squash(s_j).
                # The squash has to come after the weighted sum, not before.
                outputs = squash_vector((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    # Agreement <u_hat_j|i, v_j> increases the routing logits.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            # Each conv capsule unit yields (B, N, 1); stack units on the last
            # axis to get (B, N, num_capsules) capsule vectors.
            outputs = [capsule(x).view(x.size(0), -1, 1)
                       for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = squash_vector(outputs)

        return outputs
class MarginLoss(nn.Module):
    """
    Margin loss (Sabour et al., 2017, Eq. 4) plus a scaled-down
    reconstruction loss used as regularization.

    Args:
        m_plus: the present class's capsule length is pushed above this.
        m_minus: absent classes' capsule lengths are pushed below this.
        lambda_: down-weighting of the loss for absent classes.
        reconstruction_weight: scale applied to the summed squared
            reconstruction error.
    """

    def __init__(self, m_plus=0.9, m_minus=0.1, lambda_=0.5,
                 reconstruction_weight=0.0005):
        super(MarginLoss, self).__init__()

        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_ = lambda_
        self.reconstruction_weight = reconstruction_weight

        # Reconstruction as regularization: summed (not averaged) squared
        # error. reduction='sum' replaces the deprecated size_average=False.
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        """
        Args:
            images: input images with the same number of elements as
                ``reconstructions``, e.g. (B, 1, 28, 28).
            labels: one-hot targets, same shape as ``classes``.
            classes: capsule lengths per class, shape (B, num_classes).
            reconstructions: decoder output, shape (B, num_pixels).

        Returns:
            Scalar tensor: margin loss plus weighted reconstruction loss,
            summed over the batch and divided by the batch size.
        """
        left = F.relu(self.m_plus - classes) ** 2
        right = F.relu(classes - self.m_minus) ** 2

        margin_loss = labels * left + self.lambda_ * (1. - labels) * right
        margin_loss = margin_loss.sum()

        # Flatten the images to the decoder's layout. Current PyTorch rejects
        # MSE between (B, 1, 28, 28) and (B, 784) instead of matching them
        # element-wise.
        reconstruction_loss = self.reconstruction_loss(
            reconstructions, images.reshape_as(reconstructions))

        total = margin_loss + self.reconstruction_weight * reconstruction_loss
        return total / images.size(0)
class CapsuleNet(nn.Module):
    """
    CapsNet for single-channel 28x28 images with 10 classes (Sabour et al., 2017).

    conv1 (256 9x9 kernels) -> primary capsules (8 conv units of 32 channels,
    giving 32 * 6 * 6 8-D capsules) -> digit capsules (10 routed 16-D
    capsules) -> decoder that reconstructs the 784 input pixels from the
    masked digit capsules.
    """

    def __init__(self):
        super(CapsuleNet, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=256, kernel_size=9, stride=1)

        self.primary_capsules = CapsuleLayer(
            8, -1, 256, 32, kernel_size=9, stride=2)

        # 10 is the number of classes
        self.digit_capsules = CapsuleLayer(10, 32 * 6 * 6, 8, 16)

        # Decoder from the paper (Fig. 2): FC-ReLU 512, FC-ReLU 1024,
        # FC-Sigmoid 784, so that outputs are in the same [0, 1] range as
        # the pixels.
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid(),
        )

    def forward(self, x, y=None):
        """
        Args:
            x: image batch; the layer sizes above assume (B, 1, 28, 28).
            y: optional one-hot labels (B, 10) used to mask the digit
                capsules before decoding. If None, the longest capsule is used.

        Returns:
            (classes, reconstructions): capsule lengths of shape (B, 10) and
            reconstructions of shape (B, 784).
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        # Digit capsules come out as (10, B, 1, 1, 16). Squeeze only the two
        # singleton routing axes so a batch of size 1 keeps its batch axis,
        # then move the batch first: (B, 10, 16).
        x = self.digit_capsules(x).squeeze(3).squeeze(2).transpose(0, 1)

        # Capsule length is the class probability (paper Sec. 3). There is no
        # softmax: with lengths in [0, 1), a softmax over 10 classes never
        # exceeds e / (e + 9) ~= 0.23, so the m+ = 0.9 margin was unreachable.
        classes = (x ** 2).sum(dim=-1) ** 0.5

        if y is None:
            # In all batches, get the most active capsule
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(10, device=x.device, dtype=x.dtype).index_select(
                dim=0, index=max_length_indices)

        # x is non-contiguous after transpose, so use reshape rather than view.
        reconstructions = self.decoder(
            (x * y[:, :, None]).reshape(x.size(0), -1))
        return classes, reconstructions
if __name__ == '__main__':
    # Globals
    CUDA = torch.cuda.is_available()
    EPOCH = 10
    BATCH_SIZE = 8

    # Model
    model = CapsuleNet()
    if CUDA:
        model.cuda()

    optimizer = optim.Adam(model.parameters())
    margin_loss = MarginLoss()

    train_loader = torch.utils.data.DataLoader(
        MNIST(root='/tmp', download=True, train=True,
              transform=transforms.ToTensor()),
        batch_size=BATCH_SIZE, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        MNIST(root='/tmp', download=True, train=False,
              transform=transforms.ToTensor()),
        batch_size=BATCH_SIZE, shuffle=True)

    for e in range(EPOCH):
        # Training
        model.train()
        train_loss = 0.0
        for img, target in tqdm(train_loader, desc='Training'):
            target = index_to_one_hot(target)
            if CUDA:
                img = img.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            classes, reconstructions = model(img, target)
            loss = margin_loss(img, target, classes, reconstructions)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # margin_loss is already a per-sample mean within each batch, so
        # averaging over batches gives the epoch's average loss.
        train_loss /= len(train_loader)
        print('Training:, Avg Loss: {:.4f}'.format(train_loss))

        # Testing
        model.eval()
        correct = 0
        test_loss = 0.0
        with torch.no_grad():
            for img, target in tqdm(test_loader, desc='test set'):
                target_index = target
                target = index_to_one_hot(target)
                if CUDA:
                    img = img.cuda()
                    target = target.cuda()

                classes, reconstructions = model(img, target)
                test_loss += margin_loss(
                    img, target, classes, reconstructions).item()

                # Predicted class = index of the longest digit capsule.
                pred = classes.max(1, keepdim=True)[1].cpu()
                correct += pred.eq(target_index.view_as(pred)).sum().item()

        test_loss /= len(test_loader)
        correct = 100. * correct / len(test_loader.dataset)
        print('Test Set: Avg Loss: {:.4f}, Accuracy: {:.4f}'.format(
            test_loss, correct))
Copy link

balassbals commented Nov 17, 2017

I have a question. `logits` on line 88 has size 10 x 128 x 1152 x 1 x 16, but the softmax is taken with respect to dim 2. Shouldn't it be with respect to dim 0, since we have 10 classes? Can you clarify? (Assuming a batch size of 128.)

Copy link

Atcold commented Nov 17, 2017

@balassbals, there are a total of (6 × 6 × 32) 8D capsules u, which provide their prediction vectors \hat u. Each capsule input s is the weighted average of the corresponding \hat u. The weighting coefficient c are given by the softmax over the logits b, which are as many as the number of capsules in the layer below, i.e. 6 × 6 × 32. Therefore, it is correct to run the softmax on the 3rd dimension (i.e. dimension number 2). Please, let me know if it is not clear.

@kendricktan, the optimiser is part of the back-propagation algorithm, which starts after the forward pass. This is why I would recommend not mixing the two. I have students who confuse them...
One last thing: this code does not run when CUDA = False (line 166). Instead of cuda(), use type_as(other_tensor).

Copy link

@Atcold, I understand what you're saying, but I'm still confused. The paper says that the coupling coefficients between capsule i and all the capsules in the layer above sum to 1, and Equation 3 in the paper supports this statement. From what you say, however, my understanding is that the coefficients between all capsules in layer l and capsule j in layer l + 1 sum to one. Can you clarify?

Copy link

Atcold commented Nov 17, 2017

@balassbals, you are correct. Today I gave a talk at NYU about this paper, and people pointed out that the softmax is done across the first dimension (i.e. dimension number 0). I missed this the first time I read the paper. My bad. So you are correct: there is a mistake in this implementation.
@kendricktan did you follow the conversation? If so, please fix.

Copy link

@Atcold, but when I apply it across dim 0 (the 10 classes), I don't get the expected results. Another PyTorch implementation I saw uses F.softmax incorrectly. I actually implemented it myself first, but I'm not getting the expected results either, so I'm looking for a working version in PyTorch.

Copy link

Atcold commented Nov 20, 2017

Also, why is there a softmax() at line L152? This should simply be the capsule's norm! Correct?

Copy link

pqn commented May 12, 2018

@balassbals I have not found any working PyTorch implementations that softmax across the 10 classes (only across the 1152 routes, which does not match the paper). Have you discovered anything since?

Copy link

afmsaif commented Jun 13, 2018

I have some experience with CapsNet written in TensorFlow, but I have no idea about PyTorch. Can you help me?
I want to input data of size (224, 224, 3), and the target will be binary (0 or 1). What modifications do I have to make for this kind of data?
Thanks in advance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment