Skip to content

Instantly share code, notes, and snippets.

@rishab-sharma
Last active September 6, 2020 11:06
Show Gist options
  • Save rishab-sharma/be13edb9bc71707abf603139a6ca8a99 to your computer and use it in GitHub Desktop.
Save rishab-sharma/be13edb9bc71707abf603139a6ca8a99 to your computer and use it in GitHub Desktop.
Listening to the pixels - ASN
import torch
import torch.nn as nn
import torch.nn.functional as F
# Code from Hang Zhao (@hangzhaomit)
class InnerProd(nn.Module):
def __init__(self, fc_dim):
super(InnerProd, self).__init__()
self.scale = nn.Parameter(torch.ones(fc_dim))
self.bias = nn.Parameter(torch.zeros(1))
def forward(self, feat_img, feat_sound):
sound_size = feat_sound.size()
B, C = sound_size[0], sound_size[1]
feat_img = feat_img.view(B, 1, C)
z = torch.bmm(feat_img * self.scale, feat_sound.view(B, C, -1)) \
.view(B, 1, *sound_size[2:])
z = z + self.bias
return z
def forward_nosum(self, feat_img, feat_sound):
(B, C, H, W) = feat_sound.size()
feat_img = feat_img.view(B, C)
z = (feat_img * self.scale).view(B, C, 1, 1) * feat_sound
z = z + self.bias
return z
# inference purposes
def forward_pixelwise(self, feats_img, feat_sound):
(B, C, HI, WI) = feats_img.size()
(B, C, HS, WS) = feat_sound.size()
feats_img = feats_img.view(B, C, HI*WI)
feats_img = feats_img.transpose(1, 2)
feat_sound = feat_sound.view(B, C, HS * WS)
z = torch.bmm(feats_img * self.scale, feat_sound) \
.view(B, HI, WI, HS, WS)
z = z + self.bias
return z
class Bias(nn.Module):
def __init__(self):
super(Bias, self).__init__()
self.bias = nn.Parameter(torch.zeros(1))
# self.bias = nn.Parameter(-torch.ones(1))
def forward(self, feat_img, feat_sound):
(B, C, H, W) = feat_sound.size()
feat_img = feat_img.view(B, 1, C)
z = torch.bmm(feat_img, feat_sound.view(B, C, H * W)).view(B, 1, H, W)
z = z + self.bias
return z
def forward_nosum(self, feat_img, feat_sound):
(B, C, H, W) = feat_sound.size()
z = feat_img.view(B, C, 1, 1) * feat_sound
z = z + self.bias
return z
# inference purposes
def forward_pixelwise(self, feats_img, feat_sound):
(B, C, HI, WI) = feats_img.size()
(B, C, HS, WS) = feat_sound.size()
feats_img = feats_img.view(B, C, HI*WI)
feats_img = feats_img.transpose(1, 2)
feat_sound = feat_sound.view(B, C, HS * WS)
z = torch.bmm(feats_img, feat_sound) \
.view(B, HI, WI, HS, WS)
z = z + self.bias
return z
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment