Last active Aug 17, 2021
Clean Code for Capsule Networks
Atcold commented Nov 16, 2017 • edited

Traceback (most recent call last):
  File "capsule_networks.py", line 230, in <module>
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()
  File "/home/atcold/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 198, in view_as
    return self.view(tensor.size())
RuntimeError: invalid argument 2: size '[8 x 1]' is invalid for input of with 80 elements at /opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/TH/THStorage.c:41


Also, tqdm and print make a mess on screen.

kendricktan commented Nov 16, 2017

@Atcold Oops, my bad. It appears that I pasted in an outdated version. I've updated the gist now and removed the redundant tqdm and print output.

Atcold commented Nov 17, 2017

Very well, @kendricktan. Two more remarks.
You can (1) reintroduce tqdm in the training cycle (as long as you don't print the loss on screen), (2) factor out the feed-forward pass and loss evaluation, which are shared by both training and testing procedures. Furthermore, I'd recommend zeroing the gradient after the forward pass, and just before the backward pass, to reduce confusion.

kendricktan commented Nov 17, 2017

@Atcold, done for your remark 1.

As for remark 2, I personally think the state of the optimizer should be made explicit (zeroed) before anything else happens. Thanks for the feedback 👍

balassbals commented Nov 17, 2017 • edited

I have a doubt. The logits at line 88 have size 10 x 128 x 1152 x 1 x 16 (assuming a batch size of 128), but the softmax is taken with respect to dim 2. Shouldn't it be with respect to dim 0, since we have 10 classes? Can you clarify?

Atcold commented Nov 17, 2017 • edited

@balassbals, there are a total of 6 × 6 × 32 = 1152 8D capsules u, which provide their prediction vectors \hat u. Each capsule input s is the weighted sum of the corresponding \hat u. The weighting coefficients c are given by the softmax over the logits b, which are as many as the number of capsules in the layer below, i.e. 6 × 6 × 32. Therefore, it is correct to run the softmax on the 3rd dimension (i.e. dimension number 2). Please let me know if it is not clear.

@kendricktan, the optimiser is part of the back-propagation algorithm, which starts after the forward pass. This is why I would recommend not mixing the two things. I have students who confuse the two...
One last thing, this code does not run when CUDA = False at line 166. Instead of cuda() use type_as(other_tensor).
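A sketch of the type_as idea, assuming the tensor in question is something like a one-hot target matrix (I am not reproducing line 166 here; the one_hot helper and the toy tensors below are hypothetical):

```python
import torch


def one_hot(labels, num_classes, like):
    """Build a one-hot matrix with the same tensor type (CPU or CUDA) as `like`."""
    eye = torch.eye(num_classes).type_as(like)  # follows `like` instead of forcing .cuda()
    return eye.index_select(0, labels.to(like.device))


activations = torch.randn(4, 10)      # stand-in for a network output
labels = torch.tensor([3, 0, 9, 1])   # stand-in for integer class labels
targets = one_hot(labels, 10, activations)
```

Because the new tensor copies its type from an existing one, the same line should work whether or not CUDA is enabled.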

balassbals commented Nov 17, 2017

@Atcold, I understand what you say, but I'm still confused. The paper says that the coupling coeffs between capsule i and all the capsules in the layer above sum to 1, and equation 3 in the paper supports this statement. From what you say, however, my understanding is that the coeffs between all capsules in layer l and capsule j in layer l + 1 sum to one. Can you clarify?

Atcold commented Nov 17, 2017

@balassbals, you are correct. Today I gave a talk at NYU about this paper, and people pointed out that the softmax is done across the first dimension (i.e. dimension number 0). I missed this the first time I read the paper. My bad. So you are correct, there is a mistake in this implementation.
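To make the normalisation axis concrete, here is a small sketch in plain Python, assuming the paper's equation 3, c_ij = exp(b_ij) / Σ_k exp(b_ik), where i indexes capsules in the layer below and j (or k) capsules in the layer above. The logit values are made up, and the sizes (4 lower capsules, 3 upper capsules) are toy stand-ins for 1152 and 10:

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of numbers."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# b[i][j]: routing logit from lower capsule i to upper capsule j (made-up values)
b = [
    [0.0, 1.0, 2.0],
    [0.5, 0.5, 0.5],
    [2.0, 0.0, -1.0],
    [0.3, 0.1, 0.7],
]

# Equation 3: for each lower capsule i, normalise over the upper capsules j
c = [softmax(row) for row in b]

row_sums = [sum(row) for row in c]  # one sum per lower capsule i
col_sums = [sum(c[i][j] for i in range(len(c))) for j in range(len(c[0]))]
```

Each lower capsule's coefficients sum to one, while the sums over lower capsules for a fixed upper capsule generally do not. In a tensor laid out as 10 x batch x 1152 x 1 x 16, the upper capsules are on dimension 0.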

balassbals commented Nov 18, 2017

@Atcold, but when I do it across dim 0 (the 10 classes), I don't get the expected results. Another implementation I saw in PyTorch uses F.softmax wrongly. I actually implemented it myself first, but I'm not getting the results either, so I'm looking for a working version in PyTorch.

Atcold commented Nov 20, 2017

Also, why is there a softmax() at line 152? This should simply be the capsule's norm! Correct?
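What I have in mind, as a plain-Python sketch: the class score is the length of each output capsule's vector, and the prediction is the capsule with the largest length. The capsule values below are made up, with 3 capsules of 4 dimensions standing in for the paper's 10 capsules of 16 dimensions:

```python
import math


def capsule_norms(capsules):
    """Euclidean length of each capsule's output vector."""
    return [math.sqrt(sum(x * x for x in v)) for v in capsules]


# Hypothetical outputs of the class capsules for a single input
class_capsules = [
    [0.1, 0.0, 0.2, 0.1],
    [0.6, 0.5, 0.3, 0.2],
    [0.0, 0.1, 0.0, 0.1],
]

norms = capsule_norms(class_capsules)
predicted = max(range(len(norms)), key=norms.__getitem__)  # index of the longest capsule
```

No softmax is involved here: after squashing, each length already lies between 0 and 1 and can be read as a per-class score.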

pqn commented May 12, 2018

@balassbals I have not found any working PyTorch implementations that softmax across the 10 classes (only across the 1152 routes, which does not match the paper). Have you discovered anything since?

afmsaif commented Jun 13, 2018

Hello,
I have some experience with CapsNet written in TensorFlow, but I have no idea about PyTorch. Can you help me?
I want to feed in data of size (224, 224, 3), and the target will be binary (0 or 1). What modifications do I have to make for this kind of data?