@arunmallya
Created June 20, 2017 17:51
Performs conv only on valid input images.
# Since convolution takes the most time, only run it on images with
# mask = 1. For a 1-D masks of size N, masks.data.nonzero() is of size
# (K, 1), where K is the number of nonzero entries. To expand it to
# 4 dims for gather, we need to unsqueeze it twice.
selected_idx = Variable(masks.data.nonzero().unsqueeze(2).unsqueeze(3).repeat(
    1, images.size(1), images.size(2), images.size(3)))
selected_images = torch.gather(images, 0, selected_idx)
# Get image features from CNN and linear layer rnn_emb.
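# Flatten with view() rather than a bare squeeze(), which would also drop
# the batch dimension when only one image is selected.
cnn_feats = self.cnn(selected_images)
some_im_feats = self.rnn_emb(cnn_feats.view(cnn_feats.size(0), -1))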
# Insert image features back into the rows they came from.
im_feats = Variable(torch.FloatTensor(
    images.size(0), self.im_feat_size).zero_().type_as(images.data))
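# Once again, masks.data.nonzero() is of size (K, 1), with K the number of
# selected images, so no need to unsqueeze.
# Note: `arguments` is not defined in this snippet; it presumably comes from
# the surrounding code and is used only to match the index tensor's type.
selected_feats_idx = Variable(masks.data.nonzero().repeat(
    1, im_feats.size(1)).type_as(arguments.data))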
im_feats.scatter_(0, selected_feats_idx, some_im_feats)
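
For comparison, here is a minimal sketch of the same select, compute, and write-back pattern in more recent PyTorch, where boolean-mask indexing can stand in for the explicit gather/scatter index tensors. It is not the gist author's code. The `cnn` and `rnn_emb` modules, the `masked_image_features` helper, and all tensor sizes are hypothetical stand-ins chosen so the example runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for self.cnn and self.rnn_emb (assumed shapes).
cnn = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.AdaptiveAvgPool2d(1))
rnn_emb = nn.Linear(4, 5)


def masked_image_features(images, masks, im_feat_size=5):
    """Run cnn + rnn_emb only on rows where masks != 0; other rows stay zero."""
    keep = masks.view(-1) != 0  # (N,) boolean mask
    im_feats = torch.zeros(images.size(0), im_feat_size,
                           dtype=images.dtype, device=images.device)
    if keep.any():
        selected_images = images[keep]  # (K, C, H, W)
        # flatten(1) keeps the batch dimension even when K == 1.
        some_im_feats = rnn_emb(cnn(selected_images).flatten(1))
        im_feats[keep] = some_im_feats  # write rows back where they came from
    return im_feats


images = torch.randn(4, 3, 8, 8)
masks = torch.tensor([1, 0, 1, 0])
im_feats = masked_image_features(images, masks)
print(im_feats.shape)
```

Rows 1 and 3 of `im_feats` remain zero here because their mask entries are 0, and the `keep.any()` guard skips the CNN entirely when no image is selected.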