Implementation of a SincConv layer in PyTorch
#!/usr/bin/env python
# encoding: utf-8
# The MIT License (MIT)
# Copyright (c) 2018 CNRS
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class SincConv(nn.Module):
    """Sinc-based convolution

    Parameters
    ----------
    in_channels : `int`
        Number of input channels. Must be 1.
    out_channels : `int`
        Number of filters.
    kernel_size : `int`
        Filter length.
    sample_rate : `int`, optional
        Sample rate. Defaults to 16000.

    Usage
    -----
    See `torch.nn.Conv1d` (a minimal usage sketch is given at the end of
    this file).

    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    """

    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=False, groups=1,
                 sample_rate=16000, min_low_hz=50, min_band_hz=50):
        super().__init__()

        if in_channels != 1:
            msg = (f'SincConv only supports one input channel '
                   f'(here, in_channels = {in_channels:d}).')
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # initialize filterbanks such that they are equally spaced in Mel scale
        low_hz = 30
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel) / self.sample_rate

        # filter lower frequency (out_channels, 1)
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))

        # filter frequency band (out_channels, 1)
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Hamming window
        self.window_ = torch.hamming_window(self.kernel_size)

        # filter support (1, kernel_size), in seconds
        n = (self.kernel_size - 1) / 2
        self.n_ = torch.arange(-n, n + 1).view(1, -1) / self.sample_rate

    def sinc(self, x):
        # sin(x) / x, with the 0/0 at the center sample replaced by its limit (1)
        sinc = torch.sin(x) / x
        sinc[:, self.kernel_size // 2] = 1.
        return sinc

    def forward(self, waveforms):
        """
        Parameters
        ----------
        waveforms : `torch.Tensor` (batch_size, 1, n_samples)
            Batch of waveforms.

        Returns
        -------
        features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
            Batch of sinc filters activations.
        """
        self.n_ = self.n_.to(waveforms.device)
        self.window_ = self.window_.to(waveforms.device)

        # learned cutoff frequencies, normalized by the sample rate
        low = self.min_low_hz / self.sample_rate + torch.abs(self.low_hz_)
        high = low + self.min_band_hz / self.sample_rate + torch.abs(self.band_hz_)

        # band-pass filters built as the difference of two low-pass (sinc) filters
        f_times_t = torch.matmul(low, self.n_)
        low_pass1 = 2 * low * self.sinc(
            2 * math.pi * f_times_t * self.sample_rate)

        f_times_t = torch.matmul(high, self.n_)
        low_pass2 = 2 * high * self.sinc(
            2 * math.pi * f_times_t * self.sample_rate)

        band_pass = low_pass2 - low_pass1

        # normalize each filter to unit maximum, then apply the Hamming window
        max_, _ = torch.max(band_pass, dim=1, keepdim=True)
        band_pass = band_pass / max_

        filters = (band_pass * self.window_).view(
            self.out_channels, 1, self.kernel_size)

        return F.conv1d(waveforms, filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)