Skip to content

Instantly share code, notes, and snippets.

@Tony363
Created July 31, 2023 15:01
Show Gist options
  • Save Tony363/3995d2c31459c2e841fe2fcc46c7749f to your computer and use it in GitHub Desktop.
Save Tony363/3995d2c31459c2e841fe2fcc46c7749f to your computer and use it in GitHub Desktop.
musiq_forward.py
def forward(self, mask_inputs, feat_dis_org_embed, feat_dis_scale_1_embed, feat_dis_scale_2_embed):
    """Embed three multi-scale feature maps and run them through ``self.layers``.

    Each feature map gets a learnable per-scale embedding and a learnable 2D
    spatial embedding (looked up from the ``Grid x Grid`` table
    ``self.pos_embedding``), is flattened to a token sequence, and the three
    sequences are concatenated behind a class token.

    Unlike the previous implementation, the batch size and the spatial sizes
    are taken from the tensors themselves (not from ``config.batch_size`` or
    hard-coded 24x32 / 9x12 / 5x7), and the input tensors are no longer
    modified in place.

    Args:
        mask_inputs: forwarded unchanged (as both arguments) to
            ``get_attn_pad_mask`` together with ``self.config.i_pad``.
        feat_dis_org_embed: 4-D feature map ``(batch, C, H, W)`` of the
            original scale (the original comments quote C=384, H=24, W=32).
        feat_dis_scale_1_embed: 4-D feature map of scale 1
            (original comments: H=9, W=12).
        feat_dis_scale_2_embed: 4-D feature map of scale 2
            (original comments: H=5, W=7).

    Returns:
        ``(outputs, attn_probs)``: the output of the last layer and the list
        of per-layer attention tensors returned by each layer. Per the
        original comments these are ``(bs, n_enc_seq+1, d_hidn)`` and
        ``(bs, n_head, n_enc_seq+1, n_enc_seq+1)`` -- TODO confirm against
        the layer implementation.
    """
    grid = self.config.Grid

    def grid_indices(length, device):
        # Map every pixel coordinate to its cell in the embedding grid.
        # The float expression is kept verbatim from the original per-pixel
        # loop: integer arithmetic could round differently and shift a cell,
        # which would change the behaviour of already-trained weights.
        return torch.tensor(
            [int((k / length) * grid) for k in range(length)],
            dtype=torch.long,
            device=device,
        )

    def embed_scale(feat, scale_embedding):
        # Add scale + spatial embeddings to one feature map and flatten it
        # to tokens: (batch, C, H, W) -> (batch, H*W, C).
        b, c, h, w = feat.size()
        pos = self.pos_embedding
        rows = grid_indices(h, pos.device)
        cols = grid_indices(w, pos.device)
        # Vectorized lookup replacing the h*w Python loop:
        # (1, Grid, Grid, C) -> (1, h, w, C) -> (1, C, h, w)
        spatial = pos[:, rows][:, :, cols].permute(0, 3, 1, 2)
        spatial = spatial.to(device=feat.device, dtype=feat.dtype)
        scale = scale_embedding.to(device=feat.device, dtype=feat.dtype)
        # Out-of-place adds: broadcasting covers the batch and spatial axes,
        # and the caller's tensor is left untouched. Same summation order as
        # the original (scale first, then spatial).
        feat = feat + scale + spatial
        return feat.reshape(b, c, h * w).permute(0, 2, 1)

    per_scale = (
        (feat_dis_org_embed, self.scale_org_embedding),
        (feat_dis_scale_1_embed, self.scale_1_embedding),
        # NOTE(review): scale 2 reuses ``scale_1_embedding``, exactly as the
        # original code did. This looks like a copy-paste slip (presumably a
        # dedicated ``scale_2_embedding`` was intended), but the class
        # definition is not visible here and trained checkpoints may depend
        # on the current behaviour -- confirm before changing.
        (feat_dis_scale_2_embed, self.scale_1_embedding),
    )

    # concat the token sequences of all scales along the sequence axis
    inputs_embed = torch.cat(
        [embed_scale(feat, scale_embedding) for feat, scale_embedding in per_scale],
        dim=1,
    )

    # prepend the class token: batch x (len_seq+1) x n_feat
    cls_tokens = self.cls_token.expand(inputs_embed.size(0), -1, -1)
    x = torch.cat((cls_tokens, inputs_embed), dim=1)
    outputs = self.dropout(x)

    # (bs, n_enc_seq+1, n_enc_seq+1)
    attn_mask = get_attn_pad_mask(mask_inputs, mask_inputs, self.config.i_pad)

    attn_probs = []
    for layer in self.layers:
        # (bs, n_enc_seq+1, d_hidn), (bs, n_head, n_enc_seq+1, n_enc_seq+1)
        outputs, attn_prob = layer(outputs, attn_mask)
        attn_probs.append(attn_prob)

    # (bs, n_enc_seq+1, d_hidn), [(bs, n_head, n_enc_seq+1, n_enc_seq+1)]
    return outputs, attn_probs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment