Skip to content

Instantly share code, notes, and snippets.

/model.py Secret

Created February 22, 2018 18:35
Show Gist options
  • Save anonymous/2701c0964712e0a7fcce64ea752e391a to your computer and use it in GitHub Desktop.
Save anonymous/2701c0964712e0a7fcce64ea752e391a to your computer and use it in GitHub Desktop.
from pth_nms import pth_nms
# class Net(nn.Module):
# ...
# Replace Net's existing forward() with the two methods below (forward and
# nms_feature), indented as methods of Net.
def forward(self, v, b, q, q_len):
    """Answer a question from image features plus an NMS-based count feature.

    Args:
        v: Image/object features; L2-normalised along dim 1 below
            (presumably ``(n, channels, objects)`` -- TODO confirm).
        b: Box tensor handed unchanged to ``self.nms_feature``.
        q: Question input for ``self.text``.
        q_len: Question lengths; converted to a Python list for ``self.text``.

    Returns:
        The output of ``self.classifier(v, q, count)``.
    """
    q = self.text(q, list(q_len.data))

    # L2-normalise v along dim 1; the epsilon avoids dividing by zero for
    # all-zero vectors.
    norm = v.norm(p=2, dim=1, keepdim=True) + 1e-12
    v = v / norm.expand_as(v)

    attention = self.attention(v, q)
    v = apply_attention(v, attention)

    # Sum the attention over dim 1 (presumably the glimpse axis -- TODO
    # confirm) and reshape to (batch, -1, 1) so nms_feature can append it
    # to the boxes as an extra column.
    per_object = attention.sum(dim=1).view(attention.size(0), -1, 1)
    count = self.nms_feature(per_object, b)

    return self.classifier(v, q, count)
def nms_feature(self, a, b, nms_threshold=0.5):
    """Count objects that survive NMS with a positive score, then one-hot them.

    For every sample in the batch, the boxes from ``b`` are paired with their
    scores from ``a`` and passed through ``pth_nms``. The kept objects whose
    score is strictly positive are counted. The per-sample counts are then
    converted by ``self.counter.to_one_hot``.

    Args:
        a: Per-object scores with a trailing singleton dim, e.g. the
            ``(n, k, 1)`` tensor built in ``forward`` -- presumably summed
            attention logits.
        b: Box tensor, transposed to put objects on dim 1; presumably
            ``(n, 4, k)`` box coordinates -- TODO confirm.
        nms_threshold: Threshold forwarded to ``pth_nms``. Defaults to 0.5,
            the value previously hard-coded.

    Returns:
        The result of ``self.counter.to_one_hot`` applied to an ``(n, 1)``
        float tensor of counts located on ``a``'s device.
    """
    # Append the score as the last column after the box coordinates. The
    # result has b.size(1) + 1 columns (5 for 4-coordinate boxes), not 4 as
    # previously commented. pth_nms presumably reads the score from that last
    # column -- TODO confirm against pth_nms.
    dets = torch.cat([b.transpose(1, 2), a], dim=2)

    counts = torch.zeros(a.size(0))
    # NMS runs per sample on detached data (.data), so the count is a hard,
    # non-differentiable feature.
    for i, (sample, att) in enumerate(zip(dets.data, a.data)):
        keep = pth_nms(sample, nms_threshold)
        kept_scores = att.squeeze(dim=1).gather(0, keep)
        counts[i] = (kept_scores > 0).sum()

    # `.cuda(async=True)` is a SyntaxError on Python >= 3.7 (`async` is a
    # reserved keyword); PyTorch >= 0.4 calls this argument `non_blocking`.
    # Moving to `a.device` rather than always `.cuda()` keeps the counts on
    # the inputs' device, CPU included. The old Variable wrapper is a no-op
    # since PyTorch 0.4 and is dropped.
    counts = counts.unsqueeze(dim=1).to(a.device, non_blocking=True)
    return self.counter.to_one_hot(counts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment