Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save vadimkantorov/58edfb48122d9f4819c29c48268ed37d to your computer and use it in GitHub Desktop.
# Source code of the model from https://github.com/snakers4/silero-vad v4 extracted from the source code attributes embedded in the TorchScript structures
# I printed the code listings from the TorchScript silero_vad.jit's .code/_c.code attributes and tidied up the source a bit, nothing really fancy here
# This can be used for optimizing inference and enabling GPU inference
# Big thanks to the Silero company for making public their VAD checkpoint!
# The used checkpoint:
# https://github.com/snakers4/silero-vad/blob/a9d2b591dea11451d23aa4b480eff8e55dbd9d99/files/silero_vad.jit
import torch
import torch.nn as nn
class STFT(nn.Module):
    """Magnitude-only STFT implemented as a strided 1d convolution against a filter bank.

    The basis filters live in ``forward_basis_buffer`` and are loaded from the
    checkpoint state_dict (258 rows = real and imaginary halves stacked,
    2 * (filter_length // 2 + 1)).
    """

    def __init__(self, filter_length=256, hop_length=64):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.register_buffer('forward_basis_buffer', torch.zeros(258, 1, filter_length)) #TODO: initialize as cos/sin

    def forward(self, input_data):
        # (B, T) -> (B, 1, 1, T) so 2d reflect padding can be applied to the time axis.
        pad = int(torch.div(torch.sub(self.filter_length, self.hop_length), 2))
        padded = torch.nn.functional.pad(input_data.unsqueeze(1).unsqueeze(1), [pad, pad, 0, 0], "reflect")
        # Strided conv against the cos/sin basis performs the framed DFT.
        transform = torch.conv1d(padded.squeeze(1), self.forward_basis_buffer, None, [self.hop_length], [0])
        half = int(torch.add(torch.div(self.filter_length, 2), 1))
        real = transform[:, :half, :]
        imag = transform[:, half:, :]
        return torch.sqrt(real ** 2 + imag ** 2)
class AdaptiveAudioNormalizationNew(nn.Module):
    """Log-compress a spectrogram and subtract a smoothed channel-mean.

    The smoothing kernel ``filter_`` (width ``2 * to_pad + 1``) is loaded from
    the checkpoint state_dict.
    """

    def __init__(self, to_pad=3):
        super().__init__()
        self.to_pad = to_pad
        self.filter_ = nn.Parameter(torch.zeros(1, 1, 2 * to_pad + 1))

    @staticmethod
    def simple_pad(_mean_1, _to_pad_1):
        # Reflect-style padding along the last axis, reproducing the TorchScript slices.
        left = torch.flip(_mean_1[:, :, 1:_to_pad_1 + 1], [-1])
        right = torch.flip(_mean_1[:, :, -1 - _to_pad_1:-1], [-1])
        return torch.cat([left, _mean_1, right], 2)

    def forward(self, spect):
        compressed = torch.log1p(spect * 1048576)
        if compressed.ndim == 2:
            compressed = torch.unsqueeze(compressed, 0)
        # Mean over frequency channels, padded then smoothed by a 1d conv.
        padded = self.simple_pad(torch.mean(compressed, [1], True), self.to_pad)
        smoothed = torch.conv1d(padded, self.filter_)
        return compressed - torch.mean(smoothed, [-1], True)
class ConvBlock(nn.Module):
    """Depthwise-separable conv block with a residual connection.

    When ``proj`` is True the shortcut goes through a 1x1 conv (needed when
    ``in_channels != out_channels``); otherwise it is the identity.
    """

    def __init__(self, in_channels=258, out_channels=16, proj=False):
        super().__init__()
        self.dw_conv = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, 5, padding=2, groups=in_channels),
            nn.Identity(),
            nn.ReLU(),
        )
        self.pw_conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.Identity(),
        )
        self.proj = nn.Conv1d(in_channels, out_channels, 1) if proj else nn.Identity()
        self.activation = nn.ReLU()

    def forward(self, x):
        shortcut = self.proj(x)
        out = self.dw_conv(x)
        out = self.pw_conv(out)
        out = out + shortcut
        return self.activation(out)
class VADDecoderRNNJIT(nn.Module):
    """Two-layer LSTM followed by a 1x1 conv + sigmoid producing per-frame probabilities."""

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.1)
        self.decoder = nn.Sequential(nn.ReLU(), nn.Conv1d(64, 1, 1), nn.Sigmoid())

    def forward(self, x, h=torch.Tensor(), c=torch.Tensor()):
        # (B, C, T) -> (B, T, C) for the batch-first LSTM; empty h/c means "start fresh".
        state = (h, c) if h.numel() > 0 else None
        features, (h_out, c_out) = self.rnn(x.permute(0, 2, 1), state)
        probs = self.decoder(features.permute(0, 2, 1))
        return (probs, h_out, c_out)
class VADRNNJIT(nn.Module):
    """Single-sample-rate Silero VAD: STFT features -> conv encoder -> LSTM decoder."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = STFT()
        self.adaptive_normalization = AdaptiveAudioNormalizationNew()
        self.first_layer = nn.Sequential(ConvBlock(258, 16, proj=True), nn.Dropout(0.15))
        self.encoder = nn.Sequential(
            nn.Conv1d(16, 16, 1, stride=2),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Sequential(ConvBlock(16, 32, proj=True), nn.Dropout(0.15)),
            nn.Conv1d(32, 32, 1, stride=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Sequential(ConvBlock(32, 32, proj=False), nn.Dropout(0.15)),
            nn.Conv1d(32, 32, 1, stride=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Sequential(ConvBlock(32, 64, proj=True), nn.Dropout(0.15)),
            nn.Conv1d(64, 64, 1, stride=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.decoder = VADDecoderRNNJIT()

    def forward(self, x, h=torch.Tensor(), c=torch.Tensor()):
        spect = self.feature_extractor(x)
        # Stack raw magnitudes with their normalized version along the channel axis
        # (129 + 129 = 258 input channels for the first conv block).
        features = torch.cat([spect, self.adaptive_normalization(spect)], 1)
        encoded = self.encoder(self.first_layer(features))
        probs, h_out, c_out = self.decoder(encoded, h, c)
        # Average per-frame probabilities into a single score per chunk.
        out = torch.mean(torch.squeeze(probs, 1), [1]).unsqueeze(1)
        return (out, h_out, c_out)
class VADRNNJITMerge(nn.Module):
    """Wrapper holding a 16 kHz and an 8 kHz VAD model plus streaming LSTM state.

    Mirrors the public interface of the TorchScript silero_vad.jit checkpoint:
    ``forward(x, sr)`` scores a single chunk, ``audio_forward(x, sr)`` scores a
    whole clip chunk by chunk.
    """

    def __init__(self):
        super().__init__()
        self._model = VADRNNJIT()
        self._model_8k = VADRNNJIT()
        self._last_batch_size = None
        self._last_sr = None
        self._h = None
        self._c = None
        self.sample_rates = [8000, 16000]
        self.reset_states()

    def reset_states(self, batch_size=1):
        """Clear the streaming LSTM state and the sample-rate/batch bookkeeping."""
        # Empty tensors signal "no state yet"; the decoder then starts the LSTM fresh,
        # so batch_size is accepted for API compatibility but not actually needed.
        self._h = torch.zeros([0])
        self._c = torch.zeros([0])
        self._last_sr = 0
        self._last_batch_size = 0

    def _validate_input(self, x, sr):
        """Normalize input to a (batch, time) tensor and a supported sample rate."""
        x1 = torch.unsqueeze(x, 0) if x.ndim == 1 else x
        assert x1.ndim == 2, f"Too many dimensions for input audio chunk {x1.ndim}"
        # Multiples of 16 kHz are reduced to 16 kHz by simple striding (decimation).
        if sr != 16000 and sr % 16000 == 0:
            sr1, x2 = 16000, x1[:, ::sr // 16000]
        else:
            sr1, x2 = sr, x1
        assert sr1 in self.sample_rates, f"Supported sampling rates: {self.sample_rates} (or multiple of 16000)"
        # Require at least 512 samples at 16 kHz / 256 at 8 kHz (i.e. 32 ms of audio).
        assert sr1 / x2.shape[1] <= 31.25, "Input audio chunk is too short"
        return (x2, sr1)

    def forward(self, x, sr):
        x0, sr0 = self._validate_input(x, sr)
        # The carried LSTM state is only valid for a fixed sample rate and batch size.
        if self._last_sr and self._last_sr != sr0:
            self.reset_states()
        if self._last_batch_size and self._last_batch_size != x0.shape[0]:
            self.reset_states()
        assert sr0 == 16000 or sr0 == 8000
        # BUGFIX: dispatch on the *validated* rate sr0, not the raw sr argument
        # (e.g. sr=32000 is decimated to sr0=16000 and must use the 16 kHz model).
        model = self._model_8k if sr0 == 8000 else self._model
        out, self._h, self._c = model(x0, self._h, self._c)
        self._last_sr = sr0
        self._last_batch_size = self._h.shape[1]
        return out

    def audio_forward(self, x, sr, num_samples: int = 512):
        """Score a whole clip in fixed-size chunks; returns a (batch, n_chunks) tensor."""
        x, sr = self._validate_input(x, sr)
        if x.shape[1] % num_samples:
            # Zero-pad the tail so the clip splits evenly into num_samples chunks.
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
        self.reset_states(x.shape[0])
        outs = [self(x[:, i:i + num_samples], sr) for i in range(0, x.shape[1], num_samples)]
        return torch.cat(outs, dim=1)
if __name__ == '__main__':
    silero_torchscript_checkpoint = 'silero_vad.jit'

    # Verify we have exactly the checkpoint this extraction was made from.
    import hashlib
    with open(silero_torchscript_checkpoint, 'rb') as checkpoint_file:  # context manager: was a leaked open().read()
        assert '22aced3da46b9d9546686310f779818e' == hashlib.md5(checkpoint_file.read()).hexdigest()

    model_torchscript = torch.jit.load(silero_torchscript_checkpoint)
    model_torchscript.eval()
    state_dict = model_torchscript.state_dict()
    print(model_torchscript)

    # Load the TorchScript weights into the eager re-implementation.
    model = VADRNNJITMerge()
    model.eval()
    model.load_state_dict(state_dict)
    print(model)

    torch.set_grad_enabled(False)
    torch.set_num_threads(1)

    import torchaudio
    samples_CT, sample_rate = torchaudio.load('ru.wav') # https://models.silero.ai/vad_models/ru.wav
    assert sample_rate == 16000

    # Reference outputs from the original TorchScript model.
    model_torchscript.reset_states()
    speech_prob_torchscript = model_torchscript(samples_CT, sample_rate)
    model_torchscript.reset_states()
    speech_prob_torchscript_batch = model_torchscript.audio_forward(samples_CT, sample_rate)
    print(speech_prob_torchscript, speech_prob_torchscript_batch, speech_prob_torchscript_batch.shape)

    # The eager re-implementation must reproduce them exactly.
    model.reset_states()
    speech_prob = model(samples_CT, sample_rate)
    model.reset_states()
    speech_prob_batch = model.audio_forward(samples_CT, sample_rate)
    print(speech_prob, speech_prob_batch, speech_prob_batch.shape)

    assert torch.allclose(speech_prob_torchscript, speech_prob)
    assert torch.allclose(speech_prob_torchscript_batch, speech_prob_batch)
RecursiveScriptModule(
original_name=VADRNNJITMerge
(_model): RecursiveScriptModule(
original_name=VADRNNJIT
(adaptive_normalization): RecursiveScriptModule(original_name=AdaptiveAudioNormalizationNew)
(feature_extractor): RecursiveScriptModule(original_name=STFT)
(first_layer): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(proj): RecursiveScriptModule(original_name=Conv1d)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(encoder): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=BatchNorm1d)
(2): RecursiveScriptModule(original_name=ReLU)
(3): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(proj): RecursiveScriptModule(original_name=Conv1d)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(4): RecursiveScriptModule(original_name=Conv1d)
(5): RecursiveScriptModule(original_name=BatchNorm1d)
(6): RecursiveScriptModule(original_name=ReLU)
(7): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(8): RecursiveScriptModule(original_name=Conv1d)
(9): RecursiveScriptModule(original_name=BatchNorm1d)
(10): RecursiveScriptModule(original_name=ReLU)
(11): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(proj): RecursiveScriptModule(original_name=Conv1d)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(12): RecursiveScriptModule(original_name=Conv1d)
(13): RecursiveScriptModule(original_name=BatchNorm1d)
(14): RecursiveScriptModule(original_name=ReLU)
)
(decoder): RecursiveScriptModule(
original_name=VADDecoderRNNJIT
(rnn): RecursiveScriptModule(original_name=LSTM)
(decoder): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=ReLU)
(1): RecursiveScriptModule(original_name=Conv1d)
(2): RecursiveScriptModule(original_name=Sigmoid)
)
)
)
(_model_8k): RecursiveScriptModule(
original_name=VADRNNJIT
(adaptive_normalization): RecursiveScriptModule(original_name=AdaptiveAudioNormalizationNew)
(feature_extractor): RecursiveScriptModule(original_name=STFT)
(first_layer): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(proj): RecursiveScriptModule(original_name=Conv1d)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(encoder): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=BatchNorm1d)
(2): RecursiveScriptModule(original_name=ReLU)
(3): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(proj): RecursiveScriptModule(original_name=Conv1d)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(4): RecursiveScriptModule(original_name=Conv1d)
(5): RecursiveScriptModule(original_name=BatchNorm1d)
(6): RecursiveScriptModule(original_name=ReLU)
(7): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(8): RecursiveScriptModule(original_name=Conv1d)
(9): RecursiveScriptModule(original_name=BatchNorm1d)
(10): RecursiveScriptModule(original_name=ReLU)
(11): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(
original_name=ConvBlock
(dw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
(2): RecursiveScriptModule(original_name=ReLU)
)
(pw_conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv1d)
(1): RecursiveScriptModule(original_name=Identity)
)
(proj): RecursiveScriptModule(original_name=Conv1d)
(activation): RecursiveScriptModule(original_name=ReLU)
)
(1): RecursiveScriptModule(original_name=Dropout)
)
(12): RecursiveScriptModule(original_name=Conv1d)
(13): RecursiveScriptModule(original_name=BatchNorm1d)
(14): RecursiveScriptModule(original_name=ReLU)
)
(decoder): RecursiveScriptModule(
original_name=VADDecoderRNNJIT
(rnn): RecursiveScriptModule(original_name=LSTM)
(decoder): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=ReLU)
(1): RecursiveScriptModule(original_name=Conv1d)
(2): RecursiveScriptModule(original_name=Sigmoid)
)
)
)
)
VADRNNJITMerge(
(_model): VADRNNJIT(
(feature_extractor): STFT()
(adaptive_normalization): AdaptiveAudioNormalizationNew()
(first_layer): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(258, 258, kernel_size=(5,), stride=(1,), padding=(2,), groups=258)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(258, 16, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Conv1d(258, 16, kernel_size=(1,), stride=(1,))
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(encoder): Sequential(
(0): Conv1d(16, 16, kernel_size=(1,), stride=(2,))
(1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,), groups=16)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(4): Conv1d(32, 32, kernel_size=(1,), stride=(2,))
(5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU()
(7): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Identity()
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(8): Conv1d(32, 32, kernel_size=(1,), stride=(2,))
(9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU()
(11): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(12): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
(13): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(14): ReLU()
)
(decoder): VADDecoderRNNJIT(
(rnn): LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.1)
(decoder): Sequential(
(0): ReLU()
(1): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
(2): Sigmoid()
)
)
)
(_model_8k): VADRNNJIT(
(feature_extractor): STFT()
(adaptive_normalization): AdaptiveAudioNormalizationNew()
(first_layer): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(258, 258, kernel_size=(5,), stride=(1,), padding=(2,), groups=258)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(258, 16, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Conv1d(258, 16, kernel_size=(1,), stride=(1,))
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(encoder): Sequential(
(0): Conv1d(16, 16, kernel_size=(1,), stride=(2,))
(1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,), groups=16)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(4): Conv1d(32, 32, kernel_size=(1,), stride=(2,))
(5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU()
(7): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Identity()
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(8): Conv1d(32, 32, kernel_size=(1,), stride=(2,))
(9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU()
(11): Sequential(
(0): ConvBlock(
(dw_conv): Sequential(
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32)
(1): Identity()
(2): ReLU()
)
(pw_conv): Sequential(
(0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
(1): Identity()
)
(proj): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
(activation): ReLU()
)
(1): Dropout(p=0.15, inplace=False)
)
(12): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
(13): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(14): ReLU()
)
(decoder): VADDecoderRNNJIT(
(rnn): LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.1)
(decoder): Sequential(
(0): ReLU()
(1): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
(2): Sigmoid()
)
)
)
)
tensor([[0.5635]]) tensor([[0.0592, 0.0315, 0.0283, ..., 0.3480, 0.3140, 0.2010]]) torch.Size([1, 1875])
tensor([[0.5635]]) tensor([[0.0592, 0.0315, 0.0283, ..., 0.3480, 0.3140, 0.2010]]) torch.Size([1, 1875])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment