Created
February 9, 2022 23:55
-
-
Save philippgovernale/f9ea72dfd68011cf868a6646cfe71234 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2019 RnD at Spoon Radio | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""SpecAugment Implementation for Tensorflow. | |
Related paper : https://arxiv.org/pdf/1904.08779.pdf | |
The paper summarizes the augmentation parameters used for each open dataset in Table 1:
----------------------------------------- | |
Policy | W | F | m_F | T | p | m_T | |
----------------------------------------- | |
None | 0 | 0 | - | 0 | - | - | |
----------------------------------------- | |
LB | 80 | 27 | 1 | 100 | 1.0 | 1 | |
----------------------------------------- | |
LD | 80 | 27 | 2 | 100 | 1.0 | 2 | |
----------------------------------------- | |
SM | 40 | 15 | 2 | 70 | 0.2 | 2 | |
----------------------------------------- | |
SS | 40 | 27 | 2 | 70 | 0.2 | 2 | |
----------------------------------------- | |
LB : LibriSpeech basic | |
LD : LibriSpeech double | |
SM : Switchboard mild | |
SS : Switchboard strong | |
""" | |
import librosa | |
import librosa.display | |
import tensorflow as tf | |
from tensorflow_addons.image import sparse_image_warp | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def sparse_warp(mel_spectrogram, time_warping_para=10):
    """Apply SpecAugment time warping (step 1 of 3) to a mel spectrogram.

    A random time position at least ``W`` frames away from both ends of the
    time axis is chosen.  A column of control points at that position is
    moved along the time axis by a random distance drawn from ``[-W, W)``
    and the image is warped with ``sparse_image_warp`` to follow that move.
    No masking is performed here.

    # Arguments:
      mel_spectrogram(tensor): spectrogram indexed as
        [batch, frequency, time, channel].
        NOTE(review): the control points are cast to float32, so the image
        presumably has to be float32 as well - confirm against
        sparse_image_warp.
      time_warping_para(int): Augmentation parameter, "time warp parameter W".
        Default = 10.  It is clamped to ``(time_steps - 1) // 2`` so that a
        valid centre position always exists; if the clamped value is not
        positive the spectrogram is returned unwarped.

    # Returns
      mel_spectrogram(tensor): warped mel spectrogram.
    """
    spec_shape = tf.shape(mel_spectrogram)
    num_freq, num_frames = spec_shape[1], spec_shape[2]

    # Integer tf.random.uniform requires minval < maxval.  The two draws
    # below therefore need W >= 1 and W < num_frames - W; clamping W to
    # (num_frames - 1) // 2 guarantees the second condition.
    max_shift = tf.minimum(tf.cast(time_warping_para, tf.int32),
                           (num_frames - 1) // 2)

    def _warp():
        # Source control points: one per frequency row in the lower half of
        # the frequency axis, all sharing one random time column.
        center = tf.random.uniform([], max_shift, num_frames - max_shift,
                                   tf.int32)
        ctr_freq = tf.range(num_freq // 2)
        src_time = tf.ones_like(ctr_freq) * center

        # Destination control points: same rows, time column shifted by a
        # random signed distance.
        shift = tf.random.uniform([], -max_shift, max_shift, tf.int32)
        dest_time = src_time + shift

        # Points are (frequency, time) pairs, i.e. (row, column) of the image.
        src_pts = tf.cast(tf.stack((ctr_freq, src_time), -1), tf.float32)
        dest_pts = tf.cast(tf.stack((ctr_freq, dest_time), -1), tf.float32)

        # Add the leading batch axis expected for control points:
        # (1, num_freq // 2, 2).
        warped, _ = sparse_image_warp(mel_spectrogram,
                                      tf.expand_dims(src_pts, 0),
                                      tf.expand_dims(dest_pts, 0),
                                      num_boundary_points=2)
        return warped

    # Too few time steps (or W <= 0): there is nothing valid to sample, so
    # hand the input back unchanged instead of raising.
    return tf.cond(max_shift > 0, _warp,
                   lambda: tf.identity(mel_spectrogram))
def frequency_masking(mel_spectrogram, v, frequency_masking_para=10, frequency_mask_num=2):
    """Apply SpecAugment frequency masking (step 2 of 3).

    For each mask a width ``f`` is drawn from ``[0, F)`` and a start row
    ``f0`` from ``[0, v - f)``; rows ``[f0, f0 + f)`` of the frequency axis
    are then set to zero for every batch item and time step.

    # Arguments:
      mel_spectrogram(tensor or numpy array): spectrogram indexed as
        [batch, frequency, time, channel].
      v: ignored.  The number of frequency bins is read from the shape of
        `mel_spectrogram`; the parameter is kept so existing callers work.
      frequency_masking_para(int): Augmentation parameter, "frequency mask
        parameter F".  Default = 10.  Values larger than the number of
        frequency bins are clamped to it; values <= 0 disable masking.
      frequency_mask_num(int): number of frequency masks, "m_F".
        Default = 2.

    # Returns
      mel_spectrogram(float32 tensor): frequency-masked mel spectrogram.
    """
    mel_spectrogram = tf.convert_to_tensor(mel_spectrogram)
    spec_shape = tf.shape(mel_spectrogram)
    num_freq, num_frames = spec_shape[1], spec_shape[2]
    # Build the mask in the input's own dtype so the multiplication below
    # does not fail on a dtype mismatch for non-float32 inputs.
    mask_dtype = mel_spectrogram.dtype

    # Integer tf.random.uniform requires minval < maxval.  Capping F at the
    # number of bins keeps `num_freq - f` positive, and the floor of 1 turns
    # F <= 0 into "always draw width 0" (no masking) instead of an error.
    width_limit = tf.maximum(
        tf.minimum(tf.cast(frequency_masking_para, tf.int32), num_freq), 1)

    for _ in range(frequency_mask_num):
        f = tf.random.uniform([], minval=0, maxval=width_limit, dtype=tf.int32)
        f0 = tf.random.uniform([], minval=0, maxval=num_freq - f, dtype=tf.int32)

        # Equivalent of mel_spectrogram[:, f0:f0 + f, :, :] = 0:
        # keep rows [0, f0), zero rows [f0, f0 + f), keep the rest.
        mask = tf.concat((
            tf.ones(shape=(1, f0, num_frames, 1), dtype=mask_dtype),
            tf.zeros(shape=(1, f, num_frames, 1), dtype=mask_dtype),
            tf.ones(shape=(1, num_freq - f0 - f, num_frames, 1), dtype=mask_dtype),
        ), 1)
        mel_spectrogram = mel_spectrogram * mask

    return tf.cast(mel_spectrogram, dtype=tf.float32)
def time_masking(mel_spectrogram, tau, time_masking_para=20, time_mask_num=2):
    """Apply SpecAugment time masking (step 3 of 3).

    For each mask a width ``t`` is drawn from ``[0, T)`` and a start column
    ``t0`` from ``[0, n - t)``; columns ``[t0, t0 + t)`` of the time axis
    are then set to zero for every batch item and frequency bin.

    # Arguments:
      mel_spectrogram(tensor or numpy array): spectrogram indexed as
        [batch, frequency, time, channel].
      tau: ignored.  The number of time steps is read from the shape of
        `mel_spectrogram`; the parameter is kept so existing callers work.
      time_masking_para(int): Augmentation parameter, "time mask
        parameter T".  Default = 20.  Values larger than the number of time
        steps are clamped to it; values <= 0 disable masking.
      time_mask_num(int): number of time masks, "m_T".  Default = 2.

    # Returns
      mel_spectrogram(float32 tensor): time-masked mel spectrogram.
    """
    mel_spectrogram = tf.convert_to_tensor(mel_spectrogram)
    spec_shape = tf.shape(mel_spectrogram)
    num_freq, num_frames = spec_shape[1], spec_shape[2]
    # Build the mask in the input's own dtype so the multiplication below
    # does not fail on a dtype mismatch for non-float32 inputs.
    mask_dtype = mel_spectrogram.dtype

    # Integer tf.random.uniform requires minval < maxval.  Capping T at the
    # number of time steps keeps `num_frames - t` positive, and the floor of
    # 1 turns T <= 0 into "always draw width 0" (no masking), not an error.
    width_limit = tf.maximum(
        tf.minimum(tf.cast(time_masking_para, tf.int32), num_frames), 1)

    for _ in range(time_mask_num):
        t = tf.random.uniform([], minval=0, maxval=width_limit, dtype=tf.int32)
        t0 = tf.random.uniform([], minval=0, maxval=num_frames - t, dtype=tf.int32)

        # Equivalent of mel_spectrogram[:, :, t0:t0 + t, :] = 0:
        # keep columns [0, t0), zero columns [t0, t0 + t), keep the rest.
        mask = tf.concat((
            tf.ones(shape=(1, num_freq, t0, 1), dtype=mask_dtype),
            tf.zeros(shape=(1, num_freq, t, 1), dtype=mask_dtype),
            tf.ones(shape=(1, num_freq, num_frames - t0 - t, 1), dtype=mask_dtype),
        ), 2)
        mel_spectrogram = mel_spectrogram * mask

    return tf.cast(mel_spectrogram, dtype=tf.float32)
def spec_augment(mel_spectrogram):
    """Run the three SpecAugment steps on a mel spectrogram.

    The spectrogram is time-warped with `sparse_warp`, then passed through
    `frequency_masking` and finally `time_masking`, each with its default
    augmentation parameters.

    # Arguments:
      mel_spectrogram(tensor or numpy array): spectrogram indexed as
        [batch, frequency, time, channel].

    # Returns
      mel_spectrogram(tensor): warped and masked mel spectrogram.
    """
    # The augmentation steps index the spectrogram as
    # [batch, frequency, time, channel], so the frequency and time sizes
    # live on axes 1 and 2 (axis 0 is the batch).
    num_freq = mel_spectrogram.shape[1]
    num_frames = mel_spectrogram.shape[2]

    warped = sparse_warp(mel_spectrogram)
    freq_masked = frequency_masking(warped, v=num_freq)
    return time_masking(freq_masked, tau=num_frames)
def visualization_spectrogram(mel_spectrogram, title):
    """Plot the first spectrogram of a batch on a mel-frequency axis.

    # Arguments:
      mel_spectrogram(ndarray): four-axis array; the slice [0, :, :, 0]
        (first item, first channel) is the one displayed.
      title(String): title of the figure.
    """
    plt.figure()

    # Convert the selected power spectrogram to decibels relative to its
    # own maximum before drawing it.
    first_item = mel_spectrogram[0, :, :, 0]
    decibels = librosa.power_to_db(first_item, ref=np.max)

    librosa.display.specshow(decibels, x_axis='time', y_axis='mel', fmax=8000)
    plt.title(title)
    plt.tight_layout()
    plt.show()
def visualization_tensor_spectrogram(mel_spectrogram, title):
    """Plot the first spectrogram of a batch in a 10x4 figure.

    # Arguments:
      mel_spectrogram(ndarray): four-axis array; the slice [0, :, :, 0]
        (first item, first channel) is the one displayed.
      title(String): title of the figure.
    """
    plt.figure(figsize=(10, 4))

    # Convert the selected power spectrogram to decibels relative to its
    # own maximum before drawing it on a mel-frequency axis.
    first_item = mel_spectrogram[0, :, :, 0]
    decibels = librosa.power_to_db(first_item, ref=np.max)

    librosa.display.specshow(decibels, x_axis='time', y_axis='mel', fmax=8000)
    plt.title(title)
    plt.tight_layout()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment