Skip to content

Instantly share code, notes, and snippets.

@dschwertfeger
Created March 18, 2020 05:24
Show Gist options
  • Save dschwertfeger/6aa42626f5c720fa5d0b4467866bf916 to your computer and use it in GitHub Desktop.
Save dschwertfeger/6aa42626f5c720fa5d0b4467866bf916 to your computer and use it in GitHub Desktop.
A custom Keras layer to transform raw audio to log-mel-spectrograms
class LogMelSpectrogram(tf.keras.layers.Layer):
    """Compute log-magnitude mel-scaled spectrograms from raw waveforms.

    The layer has no trainable parameters. It runs an STFT on the input,
    projects the power spectrum onto a mel filterbank and converts the
    result to decibels relative to each example's own peak.

    Parameters
    ----------
    sample_rate : int
        Sampling rate of the incoming waveforms in Hz.
    fft_size : int
        STFT frame length and FFT length in samples.
    hop_size : int
        Number of samples between successive STFT frames.
    n_mels : int
        Number of mel bands to produce.
    f_min : float, optional
        Lower edge of the lowest mel band in Hz.
    f_max : float, optional
        Upper edge of the highest mel band in Hz. A falsy value (``None``
        or ``0``) selects the Nyquist frequency, ``sample_rate / 2``.
    **kwargs
        Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, sample_rate, fft_size, hop_size, n_mels,
                 f_min=0.0, f_max=None, **kwargs):
        super(LogMelSpectrogram, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.n_mels = n_mels
        self.f_min = f_min
        # Truthiness (not `is None`) is kept on purpose: 0 is never a valid
        # upper edge, and existing callers may pass it to mean "default".
        self.f_max = f_max if f_max else sample_rate / 2
        # Constant projection matrix of shape (fft_size // 2 + 1, n_mels).
        # It is fully determined by the constructor arguments, so it is kept
        # as a plain tensor rather than a saved weight.
        self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=self.n_mels,
            num_spectrogram_bins=fft_size // 2 + 1,
            sample_rate=self.sample_rate,
            lower_edge_hertz=self.f_min,
            upper_edge_hertz=self.f_max)

    def build(self, input_shape):
        """Mark the layer as built; there are no variables to create.

        The previous implementation appended the filterbank to
        ``self.non_trainable_weights``. That property returns a freshly
        built list, so the append never registered anything. It is dropped
        here, which leaves the layer's weight list exactly as it was.
        """
        super(LogMelSpectrogram, self).build(input_shape)

    @staticmethod
    def _log10(x):
        """Return the element-wise base-10 logarithm of ``x``."""
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
        return numerator / denominator

    @staticmethod
    def _power_to_db(power, amin=1e-16, top_db=80.0):
        """Convert power spectrograms to decibels relative to their peak.

        Modelled on ``librosa.power_to_db`` with ``ref=np.max``. The peak
        and the ``top_db`` floor are computed over the last two axes
        (time, frequency) of each example, so one example's output does
        not depend on the other examples in the batch.

        Parameters
        ----------
        power : tf.Tensor, shape = (..., time, freq)
            Non-negative power values.
        amin : float
            Lower clamp applied before taking the logarithm.
        top_db : float
            Dynamic range kept below each example's peak, in dB.

        Returns
        -------
        tf.Tensor
            Tensor of the same shape as ``power`` with values in dB.
        """
        log10 = LogMelSpectrogram._log10
        per_example = [-2, -1]
        ref_value = tf.reduce_max(power, axis=per_example, keepdims=True)
        log_spec = 10.0 * log10(tf.maximum(amin, power))
        log_spec -= 10.0 * log10(tf.maximum(amin, ref_value))
        peak_db = tf.reduce_max(log_spec, axis=per_example, keepdims=True)
        return tf.maximum(log_spec, peak_db - top_db)

    def call(self, waveforms):
        """Forward pass.

        Parameters
        ----------
        waveforms : tf.Tensor, shape = (None, n_samples)
            A batch of mono waveforms.

        Returns
        -------
        log_mel_spectrograms : tf.Tensor, shape = (None, time, freq, ch)
            The corresponding batch of log-mel-spectrograms, with a
            trailing channel axis of size 1.
        """
        # fft_length is pinned to fft_size: left unset, TF rounds it up to
        # the next power of two and the number of frequency bins would no
        # longer match the filterbank for non-power-of-two sizes.
        spectrograms = tf.signal.stft(waveforms,
                                      frame_length=self.fft_size,
                                      frame_step=self.hop_size,
                                      fft_length=self.fft_size,
                                      pad_end=False)
        power_spectrograms = tf.square(tf.abs(spectrograms))
        mel_spectrograms = tf.matmul(power_spectrograms, self.mel_filterbank)
        log_mel_spectrograms = self._power_to_db(mel_spectrograms)
        # add channel dimension
        return tf.expand_dims(log_mel_spectrograms, -1)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer.

        ``f_max`` is stored as its resolved value, i.e. the Nyquist
        frequency when the default was used.
        """
        config = {
            'fft_size': self.fft_size,
            'hop_size': self.hop_size,
            'n_mels': self.n_mels,
            'sample_rate': self.sample_rate,
            'f_min': self.f_min,
            'f_max': self.f_max,
        }
        config.update(super(LogMelSpectrogram, self).get_config())
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment