Skip to content

Instantly share code, notes, and snippets.

@hengck23
Last active September 17, 2020 09:06
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hengck23/6ebe1c75f8b3bcc953c0599ac76bad45 to your computer and use it in GitHub Desktop.
Save hengck23/6ebe1c75f8b3bcc953c0599ac76bad45 to your computer and use it in GitHub Desktop.
model
from common import *
# Number of target classes; sets the output channels of the classification heads in Net.
num_class = 266
# https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/8d9999d72b282d2dc50a5b5f668dd91369f853c5/pytorch/models.py
# https://www.kaggle.com/hidehisaarai1213/introduction-to-sound-event-detection
# https://github.com/qiuqiangkong/sed_from_wekaly_labelled_data/blob/master/spectrogram_to_wave.py
class ConvBlock(nn.Module):
    """Two (3x3 conv -> batch-norm -> ReLU) stages, optionally followed by average pooling.

    Args:
        in_channel: channel count of the incoming feature map.
        out_channel: channel count produced by both convolutions.
        pool_size: kernel size handed to ``F.avg_pool2d``; the value 1 means
            that no pooling is applied.
    """

    def __init__(self, in_channel, out_channel, pool_size=1):
        super().__init__()
        self.pool_size = pool_size

        # Both convolutions use the same geometry and carry no bias term;
        # only their input channel count differs.
        conv_kwargs = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Run both conv stages on ``x`` and pool the result unless ``pool_size`` is 1."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)), inplace=True)

        if self.pool_size == 1:
            return x
        return F.avg_pool2d(x, kernel_size=self.pool_size)
class AttentPool(nn.Module):
    """Attention weights over the last axis, one distribution per output channel.

    A kernel-size-1 convolution followed by batch-norm scores every position;
    the scores are squashed with a scaled tanh and normalised with a softmax
    along the last dimension.

    Args:
        in_channel: channel count of the input.
        out_channel: number of attention maps produced.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_channel)

    def forward(self, x):
        """Return softmax attention for ``x``, which the original notes as (batch_size, C, num_time)."""
        score = self.bn(self.conv(x))
        # Smooth bound: tanh keeps the scores inside (-10, 10) before the softmax.
        # The original also carries a commented-out torch.clamp(-10, 10) variant.
        bounded = torch.tanh(score / 10) * 10
        return bounded.softmax(dim=-1)
# type-1 roi
class ROI1(nn.Module):
    """Average away dim 3, then add a 3-wide max pool and a 3-wide mean pool of the result."""

    def forward(self, x):
        """Reduce a 4-D tensor to 3-D; the length of dim 2 is preserved (stride 1, padding 1)."""
        collapsed = x.mean(dim=3)
        window = dict(kernel_size=3, stride=1, padding=1)
        return F.max_pool1d(collapsed, **window) + F.avg_pool1d(collapsed, **window)
#-----------------------------------------------------------------------------
# Cnn14_DecisionLevelAtt
class Net(nn.Module):
    """CNN14-style network ("Cnn14_DecisionLevelAtt") with attention pooling.

    Six ConvBlock stages process a mel-spectrogram.  ROI1 plus a kernel-size-1
    convolution turn the result into frame-wise features, from which a
    frame-wise class probability and a per-class attention over frames are
    computed.  The clip-wise output is the attention-weighted sum of the
    frame-wise probabilities.
    """

    # Default location of the pretrained Cnn14_DecisionLevelAtt weights.
    PRETRAIN_CHECKPOINT = '/root/share1/kaggle/2020/birdsong/data/pretrain/Cnn14_DecisionLevelAtt_mAP0.425.pth'

    def load_pretrain(self, skip=None, is_print=True, checkpoint=None):
        """Initialise the backbone from a pretrained checkpoint.

        Args:
            skip: optional iterable of substrings; every checkpoint key that
                contains one of them is dropped in addition to the keys that
                are always dropped.
            is_print: when true, print the checkpoint path together with the
                keys that were missing from / unexpected in the checkpoint.
            checkpoint: path of the checkpoint file; defaults to
                ``PRETRAIN_CHECKPOINT``.
        """
        if checkpoint is None:
            checkpoint = self.PRETRAIN_CHECKPOINT

        # Checkpoint entries that are never loaded into this model.
        excluded = ['att_block', 'spectrogram_extractor', 'logmel_extractor', 'fc1']
        if skip is not None:
            excluded.extend(skip)

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # The map_location callable keeps every tensor on the CPU.
        state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)['model']
        for key in list(state_dict.keys()):
            if any(pattern in key for pattern in excluded):
                state_dict.pop(key, None)

        # strict=False: layers absent from the checkpoint keep their fresh
        # initialisation instead of raising.
        incompatible = self.load_state_dict(state_dict, strict=False)
        if is_print:
            print('load_pretrain: %s' % checkpoint)
            print('\tmissing_keys    = %s' % list(incompatible.missing_keys))
            print('\tunexpected_keys = %s' % list(incompatible.unexpected_keys))

    def __init__(self):
        super(Net, self).__init__()

        # Batch-norm over the frequency axis (64 bins), see forward().
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channel=1, out_channel=64, pool_size=2)
        self.conv_block2 = ConvBlock(in_channel=64, out_channel=128, pool_size=2)
        self.conv_block3 = ConvBlock(in_channel=128, out_channel=256, pool_size=2)
        self.conv_block4 = ConvBlock(in_channel=256, out_channel=512, pool_size=2)
        self.conv_block5 = ConvBlock(in_channel=512, out_channel=1024, pool_size=2)
        self.conv_block6 = ConvBlock(in_channel=1024, out_channel=2048, pool_size=1)

        self.roi = nn.Sequential(
            ROI1(),
            nn.Conv1d(2048, 2048, 1, bias=False),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
        )

        self.probability = nn.Conv1d(2048, num_class, kernel_size=1, bias=False)
        self.attention = AttentPool(2048, num_class)

    def forward(self, x):
        """Compute clip-wise class probabilities and the attention used to pool them.

        Args:
            x: mel-spectrogram of shape (batch_size, 1, num_freq, num_frame);
                ``num_freq`` has to be 64 to match ``bn0``.

        Returns:
            Tuple ``(pool, attention)``: the clip-wise probability and the
            per-class attention.  The original author's notes give
            torch.Size([8, 266]) and torch.Size([8, 266, 15]) for these.
        """
        # The unpacking doubles as a check that the input is 4-D.
        batch_size, c, num_freq, num_frame = x.shape

        # Move the frequency axis into the channel position for bn0 ...
        x = x.permute(0, 2, 1, 3).contiguous()  # (batch_size, num_freq, c, num_frame)
        x = self.bn0(x)
        # ... then hand (batch_size, c, num_frame, num_freq) to the conv stack.
        x = x.permute(0, 2, 3, 1).contiguous()

        # Dropout (p=0.2) after each block is disabled (commented out in the original).
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        x = self.conv_block6(x)  # torch.Size([8, 2048, 15, 2]) per the original notes

        # frame-wise roi; dropout (p=0.5) on it is likewise disabled.
        roi = self.roi(x)

        probability = torch.sigmoid(self.probability(roi))  # frame-wise
        attention = self.attention(roi)
        pool = torch.sum(attention * probability, dim=2)  # clip-wise
        return pool, attention
def binary_cross_entropy_with_logit_loss(logit, truth):
    """Class-balanced binary cross-entropy on raw logits for single-label targets.

    The positive and the negative entries of the one-hot target are averaged
    separately and then mixed with weights ``1/C`` (positives) and ``1 - 1/C``
    (negatives), where ``C`` is the number of classes.

    Args:
        logit: float tensor whose last dimension is the number of classes.
        truth: integer tensor of class indices, as accepted by ``F.one_hot``;
            its one-hot encoding must match the shape of ``logit``.

    Returns:
        Scalar loss tensor.  An empty batch yields 0 instead of NaN.
    """
    # Read the class count from the logits so the loss does not depend on a
    # module-level constant.
    num_class = logit.shape[-1]
    w = 1 / num_class

    onehot = F.one_hot(truth, num_class).type(logit.dtype)

    # one_hot marks exactly one positive per label, so the counts follow from
    # the shapes; this avoids the device->host sync of ``.sum().item()``.
    num_p = truth.numel()
    num_n = onehot.numel() - num_p

    log_p = -F.logsigmoid(logit)
    log_n = -F.logsigmoid(-logit)

    # max(..., 1) guards the 0/0 case (empty batch, or a single class).
    loss_p = (onehot * log_p).sum() / max(num_p, 1)
    loss_n = ((1 - onehot) * log_n).sum() / max(num_n, 1)
    return w * loss_p + (1 - w) * loss_n
def binary_cross_entropy_loss(probability, truth):
    """Class-balanced binary cross-entropy on probabilities for single-label targets.

    The positive and the negative entries of the one-hot target are averaged
    separately and then mixed with weights ``1/C`` (positives) and ``1 - 1/C``
    (negatives), where ``C`` is the number of classes.

    Args:
        probability: float tensor of probabilities whose last dimension is the
            number of classes.
        truth: integer tensor of class indices, as accepted by ``F.one_hot``;
            its one-hot encoding must match the shape of ``probability``.

    Returns:
        Scalar loss tensor.  An empty batch yields 0 instead of NaN.
    """
    # Read the class count from the input so the loss does not depend on a
    # module-level constant.
    num_class = probability.shape[-1]
    w = 1 / num_class

    # Cast to the dtype of the probabilities; casting to truth.dtype would be
    # a no-op that leaves the mask as an integer tensor.
    onehot = F.one_hot(truth, num_class).type(probability.dtype)

    # one_hot marks exactly one positive per label, so the counts follow from
    # the shapes; this avoids the device->host sync of ``.sum().item()``.
    num_p = truth.numel()
    num_n = onehot.numel() - num_p

    # Keep log() finite for probabilities of exactly 0 or 1.
    probability = torch.clamp(probability, 1e-5, 1 - 1e-5)
    log_p = -torch.log(probability)
    log_n = -torch.log(1 - probability)

    # max(..., 1) guards the 0/0 case (empty batch, or a single class).
    loss_p = (onehot * log_p).sum() / max(num_p, 1)
    loss_n = ((1 - onehot) * log_n).sum() / max(num_n, 1)
    return w * loss_p + (1 - w) * loss_n
# check #################################################################
def run_check_net():
    """Smoke test: build Net, load the pretrained weights, run one random batch and print the shapes."""
    batch_size, num_freq, num_frame = 8, 64, 501

    net = Net()
    net.load_pretrain()

    melspec = torch.randn((batch_size, 1, num_freq, num_frame))
    probability, attention = net(melspec)

    print('')
    report = (
        ('melspec: ', melspec),
        ('probability: ', probability),
        ('attention: ', attention),
    )
    for label, tensor in report:
        print(label, tensor.shape)
# main #################################################################
if __name__ == '__main__':
    # Announce the running script by file name, then run the smoke test.
    script_name = os.path.basename(__file__)
    print('%s: calling main function ... ' % script_name)
    run_check_net()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment