Skip to content

Instantly share code, notes, and snippets.

@keunwoochoi
Created October 7, 2019 18:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save keunwoochoi/f4854acb68acf791a49a051893bcd23b to your computer and use it in GitHub Desktop.
Save keunwoochoi/f4854acb68acf791a49a051893bcd23b to your computer and use it in GitHub Desktop.
melgram test code
import tensorflow as tf
import numpy as np
import librosa
class MelgramTest(tf.test.TestCase):
    """Compares LogMelgramLayer output with an equivalent librosa pipeline."""

    def test_layer(self):
        """Log-melgram of white noise should match librosa within tolerance."""
        # Shared settings: both pipelines read the same values so that the
        # TensorFlow and librosa sides cannot silently drift apart.
        test_sr = 44100
        num_fft = 1024
        hop_length = 512
        num_mels = 128
        f_max = test_sr // 2
        eps = 1e-6

        with self.session():
            # tensorflow
            mel_layer = LogMelgramLayer(
                num_fft=num_fft,
                hop_length=hop_length,
                num_mels=num_mels,
                sample_rate=test_sr,
                f_min=0.0,
                f_max=f_max,
                eps=eps,
            )

            # One second of seeded white noise keeps the test deterministic.
            np.random.seed(123)
            src = np.random.randn(test_sr).astype(np.float32)

            # self.evaluate() works in both graph and eager mode, unlike
            # sess.run(), which requires a real graph-mode session.
            tf_output = self.evaluate(mel_layer(src.reshape(1, -1)))
            # single item, remove channel axis
            tf_logmelgram = tf_output[0, :, :, 0]

            # librosa
            # center=False: no padding, to line up with pad_end=False in the layer.
            librosa_stft = np.abs(
                librosa.stft(
                    y=src,
                    n_fft=num_fft,
                    hop_length=hop_length,
                    center=False,
                    win_length=num_fft,
                )
            )
            # htk=True / norm=None: un-normalised HTK-style filterbank.
            linear_to_mel = librosa.filters.mel(
                sr=test_sr,
                n_fft=num_fft,
                n_mels=num_mels,
                fmin=0,
                fmax=f_max,
                htk=True,
                norm=None,
            )
            # (frames, freqs) power spectrogram times (freqs, mels) filterbank.
            librosa_melgram = np.dot(
                librosa_stft.T ** 2, linear_to_mel.T
            ).astype(np.float32)
            librosa_logmelgram = np.log10(librosa_melgram + eps)

            # result (as recorded by the original author)
            # - Max absolute difference: 0.00492716
            # - Max relative difference: 0.10149292
            self.assertEqual(tf_logmelgram.shape, librosa_logmelgram.shape)
            self.assertAllClose(
                librosa_logmelgram, tf_logmelgram, rtol=1e-3, atol=1e-2
            )
import tensorflow as tf
class LogMelgramLayer(tf.keras.layers.Layer):
    """Keras layer that turns mono waveforms into log10 mel-spectrograms.

    The layer computes an STFT, squares its magnitude, projects the result
    onto a mel filterbank and takes ``log10(melgram + eps)``.

    Args:
        num_fft: STFT frame length in samples; also used as the FFT size.
        hop_length: Step between successive STFT frames, in samples.
        num_mels: Number of mel bins in the output.
        sample_rate: Sample rate of the input audio, in Hz.
        f_min: Lower edge of the mel filterbank, in Hz.
        f_max: Upper edge of the mel filterbank, in Hz.
        eps: Constant added to the mel power before taking the logarithm.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self, num_fft, hop_length, num_mels, sample_rate, f_min, f_max, eps, **kwargs
    ):
        super(LogMelgramLayer, self).__init__(**kwargs)
        self.num_fft = num_fft
        self.hop_length = hop_length
        self.num_mels = num_mels
        self.sample_rate = sample_rate
        self.f_min = f_min
        self.f_max = f_max
        self.eps = eps
        # Number of unique bins of a real FFT of size num_fft.
        self.num_freqs = num_fft // 2 + 1

        # Filterbank of shape (num_freqs, num_mels), built once and reused.
        lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=self.num_mels,
            num_spectrogram_bins=self.num_freqs,
            sample_rate=self.sample_rate,
            lower_edge_hertz=self.f_min,
            upper_edge_hertz=self.f_max,
        )
        self.lin_to_mel_matrix = lin_to_mel_matrix

    def build(self, input_shape):
        """Finish layer construction; the filterbank is a fixed tensor."""
        # NOTE(review): `non_trainable_weights` is a property; whether this
        # append actually registers anything depends on the Keras version.
        # Kept as in the original - confirm before relying on it.
        self.non_trainable_weights.append(self.lin_to_mel_matrix)
        super(LogMelgramLayer, self).build(input_shape)

    def call(self, input):
        """
        Args:
            input (tensor): Batch of mono waveform, shape: (None, N)
        Returns:
            log_melgrams (tensor): Batch of log mel-spectrograms, shape: (None, num_frame, mel_bins, channel=1)
        """

        def _tf_log10(x):
            # log10(x) expressed as ln(x) / ln(10).
            numerator = tf.math.log(x)
            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
            return numerator / denominator

        stfts = tf.signal.stft(
            input,
            frame_length=self.num_fft,
            frame_step=self.hop_length,
            # Pin the FFT size: by default tf.signal.stft rounds up to the
            # next power of two, which would give a bin count different from
            # self.num_freqs whenever num_fft is not a power of two.
            fft_length=self.num_fft,
            pad_end=False,  # librosa test compatibility
        )
        mag_stfts = tf.abs(stfts)

        # stfts is (batch, frames, freq_bins); contract the frequency axis
        # with the (freq_bins, mel_bins) filterbank -> (batch, frames, mel_bins).
        melgrams = tf.tensordot(
            tf.square(mag_stfts), self.lin_to_mel_matrix, axes=[2, 0]
        )
        log_melgrams = _tf_log10(melgrams + self.eps)
        # Append a trailing channel axis of size 1.
        return tf.expand_dims(log_melgrams, 3)

    def get_config(self):
        """Return the constructor arguments merged with the base layer config."""
        config = {
            'num_fft': self.num_fft,
            'hop_length': self.hop_length,
            'num_mels': self.num_mels,
            'sample_rate': self.sample_rate,
            'f_min': self.f_min,
            'f_max': self.f_max,
            'eps': self.eps,
        }
        base_config = super(LogMelgramLayer, self).get_config()
        # Base-config entries come last, so they win on any key collision.
        return dict(list(config.items()) + list(base_config.items()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment