Skip to content

Instantly share code, notes, and snippets.

@sourabh2k15
Last active May 2, 2023 03:25
Show Gist options
  • Save sourabh2k15/80adbf1c5861e727f7698fd66e51be39 to your computer and use it in GitHub Desktop.
Save sourabh2k15/80adbf1c5861e727f7698fd66e51be39 to your computer and use it in GitHub Desktop.
#@title FFT Layer
"""Flax layer to perform preprocessing on librispeech audio inputs.
This layer computes windowed short time fourier transform over audio signals
then converts it to mel scale and finally takes a logarithm of resulting
mel spectrograms and normalizes it to be used in speech recognition models.
This code is based on lingvo's librispeech preprocessing code here:
https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/asr/frontend.py
"""
from typing import Any, Optional, Union
from flax import linen as nn
from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
# Mel spectrum constants, consumed by `_hertz_to_mel` as
#   mel = _MEL_HIGH_FREQUENCY_Q * log(1 + hertz / _MEL_BREAK_FREQUENCY_HERTZ).
# Divisor applied to the frequency (in hertz) inside the logarithm.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
# Multiplier applied to the natural logarithm.
_MEL_HIGH_FREQUENCY_Q = 1127.0
# Per-bin mean values: 80 entries, which matches the default
# `LibrispeechPreprocessingConfig.num_bins` of 80.
# NOTE(review): presumably meant to be passed as `per_bin_mean` to
# `MelFilterbankFrontend`; this constant is not referenced in the visible
# code -- confirm against the caller.
LIBRISPEECH_MEAN_VECTOR = [
-7.6047816276550293,
-7.1206226348876953,
-6.8864245414733887,
-6.8705768585205078,
-6.9667720794677734,
-7.1084094047546387,
-6.9528026580810547,
-6.783994197845459,
-6.6195521354675293,
-6.4876265525817871,
-6.4120659828186035,
-6.394047737121582,
-6.4244871139526367,
-6.3993711471557617,
-6.5158271789550781,
-6.7137999534606934,
-6.8476877212524414,
-6.9885001182556152,
-6.9221386909484863,
-7.146148681640625,
-7.2040400505065918,
-7.0537552833557129,
-7.3140382766723633,
-7.1223249435424805,
-7.30251407623291,
-7.1212143898010254,
-7.2425732612609863,
-7.1730537414550781,
-7.0979413986206055,
-7.088747501373291,
-6.9849910736083984,
-6.8787732124328613,
-6.7602753639221191,
-6.6300945281982422,
-6.5145769119262695,
-6.4245057106018066,
-6.356513500213623,
-6.31787633895874,
-6.2660770416259766,
-6.2468328475952148,
-6.2821526527404785,
-6.1908388137817383,
-6.2484354972839355,
-6.1472640037536621,
-6.0924725532531738,
-6.0171003341674805,
-5.9250402450561523,
-5.8535833358764648,
-5.8209109306335449,
-5.8118929862976074,
-5.80783748626709,
-5.7714629173278809,
-5.7453732490539551,
-5.7705655097961426,
-5.7765641212463379,
-5.7831673622131348,
-5.7954087257385254,
-5.7994823455810547,
-5.8023476600646973,
-5.8047118186950684,
-5.8168182373046875,
-5.8844799995422363,
-5.9727106094360352,
-6.0444660186767578,
-6.1284866333007812,
-6.2257585525512695,
-6.3157496452331543,
-6.39061164855957,
-6.4928598403930664,
-6.5498456954956055,
-6.6054320335388184,
-6.6508378982543945,
-6.66917610168457,
-6.6726889610290527,
-6.684234619140625,
-6.6974577903747559,
-6.75471830368042,
-6.7949142456054688,
-6.8634209632873535,
-6.94186544418335
]
# Per-bin standard-deviation values: 80 entries, which matches the default
# `LibrispeechPreprocessingConfig.num_bins` of 80.
# NOTE(review): presumably meant to be passed as `per_bin_stddev` to
# `MelFilterbankFrontend`; this constant is not referenced in the visible
# code -- confirm against the caller.
LIBRISPEECH_STD_VECTOR = [
3.4353282451629639,
3.5962932109832764,
3.7012472152709961,
3.7369205951690674,
3.7535104751586914,
3.693629264831543,
3.6922497749328613,
3.7641522884368896,
3.8419716358184814,
3.8999848365783691,
3.9294240474700928,
3.9317409992218018,
3.9139585494995117,
3.9031598567962646,
3.8691999912261963,
3.8155081272125244,
3.7644970417022705,
3.7099106311798096,
3.6965086460113525,
3.6003766059875488,
3.5493226051330566,
3.5465121269226074,
3.45003604888916,
3.4712812900543213,
3.4084610939025879,
3.4408135414123535,
3.4104881286621094,
3.4217638969421387,
3.4312851428985596,
3.4199209213256836,
3.4305806159973145,
3.4382665157318115,
3.4580366611480713,
3.4817991256713867,
3.4958710670471191,
3.5036792755126953,
3.5047574043273926,
3.4988734722137451,
3.493056058883667,
3.4822943210601807,
3.459430456161499,
3.4612770080566406,
3.4559063911437988,
3.4755423069000244,
3.4971549510955811,
3.5326557159423828,
3.5705199241638184,
3.5920312404632568,
3.596907377243042,
3.5913500785827637,
3.5865931510925293,
3.5826809406280518,
3.5837743282318115,
3.5895791053771973,
3.5819313526153564,
3.5837869644165039,
3.5861184597015381,
3.5889589786529541,
3.592214822769165,
3.5939455032348633,
3.5856630802154541,
3.5884113311767578,
3.5921022891998291,
3.5870490074157715,
3.5806570053100586,
3.5731067657470703,
3.5617532730102539,
3.54980731010437,
3.5527374744415283,
3.5475366115570068,
3.5387849807739258,
3.5256178379058838,
3.5031836032867432,
3.4922726154327393,
3.4879646301269531,
3.4725594520568848,
3.4558389186859131,
3.4351828098297119,
3.4284293651580811,
3.4299170970916748
]
@struct.dataclass
class LibrispeechPreprocessingConfig:
  """Config to hold all preprocessing options for librispeech dataset.

  The fields are read by `SpectrogramFrontend` and `MelFilterbankFrontend`.
  """
  # Samples per second; combined with the *_ms fields below to derive the
  # frame size and frame step in samples.
  sample_rate: float = 16000.0
  # Analysis window length in milliseconds.
  frame_size_ms: float = 25.0
  # Hop between consecutive frames in milliseconds.
  frame_step_ms: float = 10.0
  # If True, the magnitude spectrum is squared.
  compute_energy: bool = True
  # Window type: 'HANNING', 'HANNING_GRECO' (case-insensitive) or None for no
  # windowing; any other value makes `SpectrogramFrontend.setup` raise.
  window_fn: Optional[str] = 'HANNING'
  # Floor applied before the log when `SpectrogramFrontend.output_log` is True.
  output_log_floor: float = 1.0
  # If True, the end of the signal (and of the paddings) is padded so that
  # trailing partial frames are kept.
  pad_end: bool = False
  # Pre-emphasis coefficient; 0.0 disables pre-emphasis.
  preemph: float = 0.97
  # Selects the pre-emphasis variant that scales the first sample of each
  # frame by (1 - preemph).
  preemph_htk_flavor: bool = True
  # Scale of the Gaussian noise added to each frame; values <= 0 add no noise.
  noise_scale: float = 0.0
  # Number of mel bins produced by `MelFilterbankFrontend`.
  num_bins: int = 80
  # Lower and upper frequency edges of the mel filterbank, in hertz.
  lower_edge_hertz: float = 125.0
  upper_edge_hertz: float = 7600.0
  # NOTE(review): not read anywhere in the visible code -- confirm whether
  # this option is still needed.
  fft_overdrive: bool = False
  # Floor applied to the mel spectrogram before the log in
  # `MelFilterbankFrontend`.
  output_floor: float = 0.000010
def _hertz_to_mel(frequencies_hertz):
  """Maps frequencies in hertz onto the mel scale.

  Computes
  `_MEL_HIGH_FREQUENCY_Q * log(1 + hertz / _MEL_BREAK_FREQUENCY_HERTZ)`.

  Args:
    frequencies_hertz: A scalar or array of frequencies in hertz.

  Returns:
    The corresponding mel-scale values.
  """
  scaled = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
  return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + scaled)
def _pad_end_length(num_timesteps, frame_step, frame_size):
  """Returns how many samples must be appended for the `pad_end` feature.

  Args:
    num_timesteps: Length of the signal in samples.
    frame_step: Hop between consecutive frames in samples.
    frame_size: Length of each frame in samples.

  Returns:
    A non-negative int: the number of samples to append so that every frame
    whose first sample lies inside the signal is complete.
  """
  # The number of frames whose starting sample lies inside the signal.
  num_frames = int(np.ceil(num_timesteps / frame_step))
  # Signal length required for computing `num_frames` frames.
  padded_length = frame_step * (num_frames - 1) + frame_size
  # When frame_size < frame_step the last frame can end before the signal
  # does, which would give a negative value; the result is used as a pad
  # width, so clamp it at zero.
  return max(0, padded_length - num_timesteps)
def frame(x,
          frame_length: int,
          frame_step: int,
          pad_end: bool = False,
          pad_value: Union[int, float] = 0.0):
  """Cuts a signal into overlapping frames with a sliding window.

  Frame `n` holds `x[:, n * frame_step : n * frame_step + frame_length, :]`.
  Unlike `tf.signal.frame` there is no `axis` argument: the time axis is
  always axis 1 of a `(batch_size, timesteps, channels)` input.

  Args:
    x: An input array with `(batch_size, timesteps, channels)`-shape.
    frame_length: Number of samples in each frame.
    frame_step: Hop between the starts of consecutive frames.
    pad_end: If True, the end of the signal is padded so that the window can
      keep sliding while its starting point is inside the signal.
    pad_value: Scalar used to fill the padding when `pad_end` is True.

  Returns:
    An array with shape
    `(batch_size, num_frames, frame_length, num_channels)`.
  """
  _, num_timesteps, num_channels = x.shape

  if pad_end:
    extra = _pad_end_length(num_timesteps, frame_step, frame_length)
    pad_widths = ((0, 0), (0, extra), (0, 0))
    x = jnp.pad(x, pad_widths, 'constant', constant_values=pad_value)

  # Each output position carries one flattened patch, laid out channel-major
  # (channel, offset-within-frame) along the last axis.
  patches = jax.lax.conv_general_dilated_patches(
      x, (frame_length,), (frame_step,),
      'VALID',
      dimension_numbers=('NTC', 'OIT', 'NTC'))

  unflattened = jnp.reshape(
      patches, patches.shape[:-1] + (num_channels, frame_length))
  # Move the channel axis last: (..., channels, frame) -> (..., frame, channels).
  return jnp.swapaxes(unflattened, 2, 3)
def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
                                num_spectrogram_bins: int = 129,
                                sample_rate: Union[int, float] = 8000,
                                lower_edge_hertz: Union[int, float] = 125.0,
                                upper_edge_hertz: Union[int, float] = 3800.0,
                                dtype: Any = jnp.float32):
  r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`.

  Builds a bank of triangular filters that are linear in the mel domain and
  returns them as a matrix mapping linear-frequency spectrogram bins to mel
  bins.

  Args:
    num_mel_bins: Python int. How many bands in the resulting mel spectrum.
    num_spectrogram_bins: Python int. How many bins there are in the source
      spectrogram data, which is understood to be `fft_size // 2 + 1`, i.e.
      the spectrogram only contains the nonredundant FFT bins.
    sample_rate: Python int or float. Samples per second of the input signal
      used to create the spectrogram. Used to figure out the frequencies
      corresponding to each spectrogram bin, which dictates how they are
      mapped into the mel scale.
    lower_edge_hertz: Python float. Lower bound on the frequencies to be
      included in the mel spectrum. This corresponds to the lower edge of the
      lowest triangular band.
    upper_edge_hertz: Python float. The desired top edge of the highest
      frequency band.
    dtype: The dtype of the result matrix. Must be a floating point type.

  Returns:
    An array of shape `[num_spectrogram_bins, num_mel_bins]`.

  Raises:
    ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
      positive, `lower_edge_hertz` is negative, frequency edges are
      incorrectly ordered, `upper_edge_hertz` is larger than the Nyquist
      frequency.

  [mel]: https://en.wikipedia.org/wiki/Mel_scale
  """
  # Input validator from tensorflow/python/ops/signal/mel_ops.py#L71
  if num_mel_bins <= 0:
    raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
  # The docstring has always promised this check; without it a non-positive
  # value silently produced a degenerate matrix.
  if num_spectrogram_bins <= 0:
    raise ValueError('num_spectrogram_bins must be positive. Got: %s' %
                     num_spectrogram_bins)
  if lower_edge_hertz < 0.0:
    raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
                     lower_edge_hertz)
  if lower_edge_hertz >= upper_edge_hertz:
    raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
                     (lower_edge_hertz, upper_edge_hertz))
  if sample_rate <= 0.0:
    raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
  if upper_edge_hertz > sample_rate / 2:
    raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
                     'frequency (sample_rate / 2). Got %s for sample_rate: %s' %
                     (upper_edge_hertz, sample_rate))

  # HTK excludes the spectrogram DC bin.
  bands_to_zero = 1
  nyquist_hertz = sample_rate / 2.0
  linear_frequencies = jnp.linspace(
      0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:]
  spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis]

  # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
  # center of each band is the lower and upper edge of the adjacent bands.
  # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
  # num_mel_bins + 2 pieces.
  edges = jnp.linspace(
      _hertz_to_mel(lower_edge_hertz),
      _hertz_to_mel(upper_edge_hertz),
      num_mel_bins + 2,
      dtype=dtype)

  # Split the triples up and reshape them into [1, num_mel_bins] tensors.
  lower_edge_mel = edges[:-2][jnp.newaxis, :]
  center_mel = edges[1:-1][jnp.newaxis, :]
  upper_edge_mel = edges[2:][jnp.newaxis, :]

  # Calculate lower and upper slopes for every spectrogram bin.
  # Line segments are linear in the mel domain, not Hertz.
  lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
      center_mel - lower_edge_mel)
  upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
      upper_edge_mel - center_mel)

  # Intersect the line segments with each other and zero.
  mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes))

  # Re-add the zeroed lower bins we sliced out above.
  return jnp.pad(mel_weights_matrix, [[bands_to_zero, 0], [0, 0]])
def _hanning_greco(win_support, frame_size, dtype):
  """Builds a greco-style hanning window.

  Note that the Hanning window in Wikipedia is not the same as the Hanning
  window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia
  page would indicate: the cosine is sampled at half-sample offsets
  (`n + 0.5`), so no sample is spent on an exactly-zero edge value.

  Args:
    win_support: Number of samples for non-zero support in the window.
    frame_size: Total size of the window (frame_size >= win_support).
    dtype: Data type of the returned window.

  Returns:
    Array of size frame_size holding the window, zero-padded on the right
    beyond `win_support`.

  Raises:
    ValueError: If `frame_size` is smaller than `win_support`.
  """
  if frame_size < win_support:
    raise ValueError(
        'Provided frame_size = {} is lower than win_support = {}'.format(
            frame_size, win_support))

  radians_per_sample = jnp.pi * 2.0 / (win_support)
  half_sample_positions = jnp.arange(win_support, dtype=dtype) + 0.5
  hann = 0.5 - (0.5 * jnp.cos(radians_per_sample * half_sample_positions))

  # Fill the rest of the frame with zeros after the window's support.
  trailing_zeros = frame_size - win_support
  return jnp.pad(hann, [(0, trailing_zeros)])
def _next_pow_of_two(x: Union[int, float]) -> int:
  """Returns `2**ceil(log2(x))` truncated to an int.

  For `x >= 1` this is the smallest power of two not below `x`.
  """
  exponent = np.ceil(np.log2(x))
  return int(2**exponent)
class SpectrogramFrontend(nn.Module):
  """Layer to convert input audio signals from time domain to frequency domain.

  The signal is framed, optionally pre-emphasized, perturbed with noise and
  windowed, then transformed with a real FFT.

  Attributes:
    config: Preprocessing options (frame size/step, window, pre-emphasis...).
    input_scale_factor: Multiplier applied to the samples before framing.
    output_log: If True, the output is
      `log(max(spectrum, config.output_log_floor))`.
  """
  config: LibrispeechPreprocessingConfig = None
  input_scale_factor: float = 1.0
  output_log: bool = False

  def setup(self) -> None:
    """Derives frame sizes, FFT size and the window function from `config`."""
    cfg = self.config
    self._frame_step = int(round(cfg.sample_rate * cfg.frame_step_ms / 1000.0))
    frame_samples = int(round(cfg.sample_rate * cfg.frame_size_ms / 1000.0))
    # One extra sample per frame is consumed by the pre-emphasis step.
    self._frame_size = frame_samples + 1
    # TF-version has maximum of 512, but it's not always necessary
    self.fft_size = _next_pow_of_two(self._frame_size)

    if cfg.window_fn is None:
      self._window_fn = None
      return

    window_kind = cfg.window_fn.upper()
    if window_kind == 'HANNING':

      def _periodic_hanning(frame_size, dtype):
        # For even sizes build a window one point longer and drop the last
        # point, simulating periodic=True in tf.signal.hann_window.
        is_even = frame_size % 2 == 0
        num_points = frame_size + 1 if is_even else frame_size
        window = jnp.hanning(num_points).astype(dtype)
        return window[:-1] if is_even else window

      self._window_fn = _periodic_hanning
    elif window_kind == 'HANNING_GRECO':
      # Greco-compatible hanning window whose support excludes the extra
      # pre-emphasis sample.
      win_support = self._frame_size - 1

      def _greco_hanning(frame_size, dtype):
        return _hanning_greco(win_support, frame_size, dtype)

      self._window_fn = _greco_hanning
    else:
      raise ValueError('Illegal value %r for window_fn param' % cfg.window_fn)

  def _apply_preemphasis(self, framed_signal):
    """Applies pre-emphasis along the frame axis (axis 2).

    The result is one sample shorter than the input along that axis.
    """
    coeff = self.config.preemph
    if not self.config.preemph_htk_flavor:
      return (framed_signal[:, :, 1:, :] -
              coeff * framed_signal[:, :, :-1, :])

    # HTK flavor: the first sample is scaled instead of differenced, and the
    # last input sample is dropped.
    first = framed_signal[:, :, :1, :] * (1. - coeff)
    rest = (framed_signal[:, :, 1:-1, :] -
            coeff * framed_signal[:, :, :-2, :])
    return jnp.concatenate([first, rest], axis=2)

  def fprop_paddings(self, input_paddings):
    """Converts per-sample paddings into per-frame paddings.

    A frame's padding is the minimum over its samples, i.e. a frame counts as
    padding only if every sample in it is padding.
    """
    if self.config.pad_end:
      extra = _pad_end_length(input_paddings.shape[1],
                              self._frame_step,
                              self._frame_size)
      input_paddings = jnp.pad(
          input_paddings, ((0, 0), (0, extra)), constant_values=1.0)

    return jax.lax.reduce_window(
        input_paddings,
        init_value=1.0,
        computation=jax.lax.min,
        window_dimensions=[1, self._frame_size],
        window_strides=[1, self._frame_step],
        padding='valid')

  def next_prng_key(self, name='dropout'):
    """Returns a fresh PRNG key from the named RNG stream."""
    return self.make_rng(name)

  @nn.compact
  def __call__(self, inputs, input_paddings):
    """Computes the spectrogram of `inputs`.

    Args:
      inputs: Audio samples; a 2-D input gets a trailing channel axis added.
      input_paddings: Per-sample paddings (1 marks padding), or None.

    Returns:
      A `(spectrum, output_paddings)` pair; `output_paddings` is None when
      `input_paddings` is None.
    """
    cfg = self.config
    inputs = inputs.astype(jnp.float32)
    # Expand to have a channel axis
    if inputs.ndim == 2:
      inputs = jnp.expand_dims(inputs, -1)

    output_paddings = None
    if input_paddings is not None:
      # Zero out padded samples before framing.
      inputs = inputs * jnp.expand_dims(1.0 - input_paddings, -1)
      output_paddings = self.fprop_paddings(input_paddings)

    scaled_audio = inputs.astype(jnp.float32) * self.input_scale_factor
    framed = frame(
        scaled_audio, self._frame_size, self._frame_step, pad_end=cfg.pad_end)

    if cfg.preemph != 0.0:
      emphasized = self._apply_preemphasis(framed)
    else:
      # Drop the extra sample that pre-emphasis would have consumed.
      emphasized = framed[..., :-1, :]

    if cfg.noise_scale > 0.0:
      noise = jax.random.normal(self.next_prng_key(),
                                emphasized.shape) * cfg.noise_scale
    else:
      noise = jnp.zeros(emphasized.shape)
    signal = emphasized + noise

    # Window here
    if self._window_fn is not None:
      window_length = self._frame_size - 1
      window = self._window_fn(window_length, framed.dtype)
      signal = signal * window.reshape((1, 1, window_length, 1))

    spectrum = jnp.abs(jnp.fft.rfft(signal, self.fft_size, axis=2))
    if cfg.compute_energy:
      spectrum = spectrum**2.0

    if self.output_log:
      spectrum = jnp.log(jnp.maximum(spectrum, cfg.output_log_floor))
    return spectrum, output_paddings
class MelFilterbankFrontend(nn.Module):
  """Layer to compute normalized log mel spectrograms from input audio signals.

  Attributes:
    config: Preprocessing options shared with the inner `SpectrogramFrontend`.
    use_divide_stream: If True, inputs are scaled by 2**-15 before the STFT.
    per_bin_mean: Per-mel-bin means subtracted from the log mel spectrogram;
      None means zeros. It is indexed as a 1-D array over bins.
    per_bin_stddev: Per-mel-bin standard deviations the result is divided by;
      None means ones. It is indexed as a 1-D array over bins.
  """
  config: LibrispeechPreprocessingConfig = None
  use_divide_stream: bool = True
  per_bin_mean: Optional[float] = None
  per_bin_stddev: Optional[float] = None

  def setup(self):
    """Creates the STFT sub-layer and the broadcastable normalizer arrays."""
    cfg = self.config
    scale = 2**-15 if self.use_divide_stream else 1.0
    self.stft = SpectrogramFrontend(
        cfg, input_scale_factor=scale, output_log=False)

    means = self.per_bin_mean
    if means is None:
      means = [0.0] * cfg.num_bins
    stddevs = self.per_bin_stddev
    if stddevs is None:
      stddevs = [1.0] * cfg.num_bins

    # Reshape (num_bins,) -> (1, 1, num_bins, 1) so the statistics broadcast
    # against the 4-D mel spectrogram along its bin axis.
    self._normalizer_mean = jnp.array(means)[None, None, :, None]
    self._normalizer_stddev = jnp.array(stddevs)[None, None, :, None]

  @nn.compact
  def __call__(self, inputs, input_paddings):
    """Returns `(normalized_log_mel, paddings)` for the given audio.

    Args:
      inputs: Audio samples, forwarded to `SpectrogramFrontend`.
      input_paddings: Per-sample paddings, forwarded to `SpectrogramFrontend`.

    Returns:
      The normalized log mel spectrogram with its trailing axis squeezed out,
      and the per-frame paddings produced by the STFT layer.
    """
    cfg = self.config
    spectrum, frame_paddings = self.stft(inputs, input_paddings)

    mel_weights = linear_to_mel_weight_matrix(
        num_mel_bins=cfg.num_bins,
        num_spectrogram_bins=spectrum.shape[2],
        sample_rate=cfg.sample_rate,
        lower_edge_hertz=cfg.lower_edge_hertz,
        upper_edge_hertz=cfg.upper_edge_hertz)

    # Project the frequency axis (f) onto the mel bins (n).
    mel = jnp.einsum('fn,btfc->btnc', mel_weights, spectrum)
    log_mel = jnp.log(jnp.maximum(mel, cfg.output_floor))
    normalized = (log_mel - self._normalizer_mean) / self._normalizer_stddev
    return jnp.squeeze(normalized, -1), frame_paddings
#@title SpecAug Layer
"""A flax layer to do data augmentation for audio signals as
described in https://arxiv.org/abs/1904.08779.
Code based on:
github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/spectrum_augmenter.py
"""
import flax.linen as nn
import jax
import jax.numpy as jnp
class SpecAug(nn.Module):
  """Layer that masks spans of the input along the time and frequency axes.

  The procedure is detailed in https://arxiv.org/abs/1904.08779.

  This is an essential component in speech recognition models that helps achieve
  better word error rates.

  Attributes:
    freq_mask_count: Number of frequency masks drawn per example.
    freq_mask_max_bins: Upper bound on the width of one frequency mask.
    time_mask_count: Number of time masks drawn per example.
    time_mask_max_frames: Upper bound on the length of one time mask; ignored
      when `use_dynamic_time_mask_max_frames` is True.
    time_mask_max_ratio: Upper bound on a time mask's length as a fraction of
      the example's length.
    time_masks_per_frame: If positive, only the first
      `time_masks_per_frame * length` time masks of an example are applied.
    use_dynamic_time_mask_max_frames: If True, the time mask length bound is
      `length * time_mask_max_ratio` instead of `time_mask_max_frames`.
  """
  freq_mask_count: int = 2
  freq_mask_max_bins: int = 27
  time_mask_count: int = 10
  time_mask_max_frames: int = 40
  time_mask_max_ratio: float = 0.05
  time_masks_per_frame: float = 0.0
  use_dynamic_time_mask_max_frames: bool = True

  def next_prng_key(self, name='dropout'):
    """Returns a fresh PRNG key from the named RNG stream."""
    return self.make_rng(name)

  def _get_mask(self,
                batch_size,
                choose_range,
                mask_size,
                max_length=None,
                masks_per_frame=0.0,
                multiplicity=1,
                max_ratio=1.0):
    """Samples a multiplicative mask with randomly placed zeroed spans.

    Args:
      batch_size: Number of examples.
      choose_range: `(batch_size,)` array; the span start is drawn so that the
        span fits inside `[0, choose_range)` for each example.
      mask_size: Length of the axis being masked.
      max_length: Fixed upper bound on the span length; when falsy the bound
        is `choose_range * max_ratio`.
      masks_per_frame: If positive, only the first
        `masks_per_frame * choose_range` spans of each example are applied.
      multiplicity: Number of spans drawn per example.
      max_ratio: Upper bound on the span length relative to `choose_range`.

    Returns:
      A `(batch_size, mask_size)` array holding 0.0 inside the sampled spans
      and 1.0 elsewhere.
    """
    # Per-example upper bound on the sampled span length.
    if max_length and max_length > 0:
      max_length = jnp.tile(max_length, (batch_size,))
    else:
      max_length = choose_range * max_ratio

    # Sample lengths for multiple masks.
    per_mask_shape = (batch_size, multiplicity)
    masked_portion = jax.random.uniform(
        key=self.next_prng_key(),
        shape=per_mask_shape,
        minval=0.0,
        maxval=1.0)
    masked_frame_size = jnp.einsum('b,bm->bm', max_length,
                                   masked_portion).astype(jnp.int32)

    # Make sure the sampled length was smaller than max_ratio * length_bound.
    # Note that sampling in this way was biased
    # (shorter sequence may over-masked.)
    choose_range = jnp.tile(choose_range[:, None], [1, multiplicity])
    length_bound = (max_ratio * choose_range).astype(jnp.int32)
    length = jnp.minimum(masked_frame_size, jnp.maximum(length_bound, 1))

    # Choose starting point.
    random_start = jax.random.uniform(
        key=self.next_prng_key(), shape=per_mask_shape, maxval=1.0)
    start = (random_start * (choose_range - length + 1)).astype(jnp.int32)
    end = start + length - 1

    # Shift starting and end point by small value so the strict comparisons
    # below include both integer endpoints.
    delta = 0.1
    lower = jnp.expand_dims(start - delta, -1)
    upper = jnp.expand_dims(end + delta, -1)

    # Construct pre-mask of shape (batch_size, multiplicity, mask_size) by
    # broadcasting the positions against the per-mask bounds.
    positions = jnp.arange(mask_size)[None, None, :]
    pre_mask = jnp.logical_and(positions < upper, positions > lower)

    # Sum masks with appropriate multiplicity.
    if masks_per_frame > 0:
      mask_index = jnp.arange(multiplicity, dtype=jnp.int32)[None, :]
      multiplicity_tensor = masks_per_frame * choose_range
      multiplicity_weights = (mask_index <
                              multiplicity_tensor).astype(jnp.int32)
      pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
    else:
      pre_mask = jnp.einsum('bmt->bt', pre_mask)

    return 1.0 - (pre_mask > 0).astype(jnp.int32)

  def _time_mask(self, inputs, length):
    """Zeroes out random spans of `inputs` along the time axis (axis 1)."""
    multiplicity = self.time_mask_count
    max_ratio = self.time_mask_max_ratio
    dynamic = self.use_dynamic_time_mask_max_frames

    # If maximum mask length is zero, do nothing.
    fixed_length_is_zero = self.time_mask_max_frames == 0 and not dynamic
    if fixed_length_is_zero or max_ratio <= 0.0:
      return inputs
    if multiplicity == 0:
      return inputs

    batch_size, time_length, _ = inputs.shape

    # When using dynamic time mask size, discard upper-bound on
    # maximum allowed frames for time mask.
    max_frames = None if dynamic else self.time_mask_max_frames

    # Create masks in time direction and apply.
    time_mask = self._get_mask(
        batch_size,
        choose_range=length,
        mask_size=time_length,
        max_length=max_frames,
        masks_per_frame=self.time_masks_per_frame,
        multiplicity=multiplicity,
        max_ratio=max_ratio)
    return jnp.einsum('bxy,bx->bxy', inputs, time_mask)

  def _frequency_mask(self, inputs):
    """Zeroes out random spans of `inputs` along the last axis (axis 2)."""
    max_bins = self.freq_mask_max_bins
    multiplicity = self.freq_mask_count

    # If masking length or count is zero, do nothing.
    if max_bins == 0 or multiplicity == 0:
      return inputs

    # Every example may be masked anywhere along the full frequency axis.
    batch_size, _, num_freq = inputs.shape
    full_range = jnp.tile(num_freq, (batch_size,))

    # Create masks in frequency direction and apply.
    freq_mask = self._get_mask(
        batch_size,
        choose_range=full_range,
        mask_size=num_freq,
        max_length=max_bins,
        masks_per_frame=0.0,
        multiplicity=multiplicity,
        max_ratio=1.0)
    return jnp.einsum('bxy,by->bxy', inputs, freq_mask)

  @nn.compact
  def __call__(self, inputs, paddings):
    """Applies time masking, then frequency masking.

    Args:
      inputs: 3-D array; axis 1 is masked as time and axis 2 as frequency.
      paddings: `(batch, time)` array with 1 marking padded positions; used to
        derive each example's unpadded length.

    Returns:
      The masked inputs and the unchanged `paddings`.
    """
    lengths = jnp.einsum('bh->b', 1 - paddings).astype(jnp.int32)
    masked = self._time_mask(inputs, lengths)
    masked = self._frequency_mask(masked)
    return masked, paddings
#@title CudnnLSTM Layer
from typing import Any, Optional, Sequence, Tuple, Union
from flax import linen as nn
import jax
from jax.experimental import rnn
import jax.numpy as jnp
import numpy as np
# Type aliases used in the annotations of the layers below.
Array = jnp.ndarray
# A recurrent state: a single array or a tuple of arrays.
StateType = Union[Array, Tuple[Array, ...]]
PRNGKey = Any
# Variable-length tuple of ints (`Tuple[int]` would mean exactly one element).
Shape = Tuple[int, ...]
Dtype = Any
class CudnnLSTM(nn.Module):
  """Multi-layer LSTM implemented with `jax.experimental.rnn`.

  Attributes:
    input_size: Size of the feature axis of the inputs.
    hidden_size: Size of the hidden state of each layer and direction.
    num_layers: Number of stacked LSTM layers.
    dropout_rate: Dropout rate passed to `rnn.lstm` when not deterministic.
    bidirectional: If True, a bidirectional LSTM is used.
  """
  input_size: int
  hidden_size: int
  num_layers: int
  dropout_rate: float = 0.0
  bidirectional: bool = False

  def setup(self):
    """Creates the single weight parameter holding all LSTM weights."""
    self.w = self.param(
        'lstm_weights',
        rnn.init_lstm_weight,
        self.input_size,
        self.hidden_size,
        self.num_layers,
        self.bidirectional,
    )

  def __call__(
      self,
      inputs: Array,
      input_paddings: Array,
      initial_states: Optional[Sequence[StateType]] = None,
      deterministic: bool = False,
  ) -> Array:
    """Runs the LSTM over `inputs`, starting from zero states.

    Args:
      inputs: Input sequences; axis 0 is the batch axis.
      input_paddings: Paddings with 1 marking padded positions; the sequence
        lengths are derived by summing `1 - input_paddings` over the last
        axis.
      initial_states: Not supported yet; must be None.
      deterministic: If True, dropout is disabled.

    Returns:
      The output sequence `y` of `rnn.lstm`; the final hidden and cell states
      are discarded.

    Raises:
      NotImplementedError: If `initial_states` is not None.
    """
    # TODO(zhangqiaorjc): initial_states
    # Raise explicitly rather than `assert`, which is stripped under `-O` and
    # would let the argument be silently ignored.
    if initial_states is not None:
      raise NotImplementedError(
          'CudnnLSTM does not support initial_states yet.')

    num_directions = 2 if self.bidirectional else 1
    batch_size = inputs.shape[0]
    dropout = 0.0 if deterministic else self.dropout_rate

    # Zero initial hidden and cell states, one per layer and direction.
    state_shape = (
        num_directions * self.num_layers, batch_size, self.hidden_size)
    h_0 = jnp.zeros(state_shape, jnp.float32)
    c_0 = jnp.zeros(state_shape, jnp.float32)

    seq_lengths = jnp.sum(1.0 - input_paddings, axis=-1, dtype=jnp.int32)

    y, _, _ = rnn.lstm(
        inputs,
        h_0,
        c_0,
        self.w,
        seq_lengths,
        self.input_size,
        self.hidden_size,
        self.num_layers,
        dropout,
        self.bidirectional,
    )
    return y
#@title Deepspeech Model
r"""Deepspeech.
This model uses a deepspeech2 network to convert speech to text.
paper : https://arxiv.org/abs/1512.02595
# BiLSTM code contributed by bastings@
# github : https://github.com/bastings
# webpage : https://bastings.github.io/
"""
import functools
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Type, Union
import flax
from flax import linen as nn
from flax import struct
import jax
import jax.numpy as jnp
# Type aliases used in the annotations of the model code below.
Array = jnp.ndarray
# A recurrent state: a single array or a tuple of arrays.
StateType = Union[Array, Tuple[Array, ...]]
PRNGKey = Any
# Variable-length tuple of ints (`Tuple[int]` would mean exactly one element).
Shape = Tuple[int, ...]
Dtype = Any
@struct.dataclass
class DeepspeechConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  # NOTE(review): the consumers of vocab_size, num_lstm_layers,
  # num_ffn_layers, conv_subsampling_factor and conv_subsampling_layers are
  # not in the visible code -- confirm their meaning against the model class.
  vocab_size: int = 1024
  # dtype forwarded to Conv2dSubsampling and BatchNorm.
  dtype: Any = jnp.float32
  # Output width of the Dense layers and of the subsampling convolutions.
  encoder_dim: int = 512
  num_lstm_layers: int = 6
  num_ffn_layers: int = 3
  conv_subsampling_factor: int = 2
  conv_subsampling_layers: int = 2
  # The SpecAug options below mirror the attribute names and defaults of the
  # `SpecAug` layer.
  use_specaug: bool = True
  freq_mask_count: int = 2
  freq_mask_max_bins: int = 27
  time_mask_count: int = 10
  time_mask_max_frames: int = 40
  time_mask_max_ratio: float = 0.05
  time_masks_per_frame: float = 0.0
  use_dynamic_time_mask_max_frames: bool = True
  # Momentum and epsilon forwarded to the padding-aware BatchNorm.
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001
  # If None, defaults to 0.1.
  input_dropout_rate: Optional[float] = 0.1
  # If None, defaults to 0.1.
  feed_forward_dropout_rate: Optional[float] = 0.1
  # NOTE(review): the flags below are not read in the visible code -- confirm
  # against the model class.
  enable_residual_connections: bool = True
  enable_decoder_layer_norm: bool = True
  bidirectional: bool = True
  use_cudnn_lstm: bool = False
class Subsample(nn.Module):
  """Module to perform strided convolution in order to subsample inputs.

  Two `Conv2dSubsampling` layers are applied, then the result is flattened per
  time step, projected to `encoder_dim` and passed through dropout.

  Attributes:
    config: Model hyperparameters; this module reads `encoder_dim`, `dtype`,
      `batch_norm_momentum`, `batch_norm_epsilon` and `input_dropout_rate`.
  """
  config: DeepspeechConfig

  @nn.compact
  def __call__(self, inputs, output_paddings, train):
    """Subsamples `inputs` and their paddings.

    Args:
      inputs: Input features; a trailing channel axis is added before the
        convolutions.
      output_paddings: Paddings of `inputs`, updated by every convolution.
      train: If False, dropout is disabled.

    Returns:
      The subsampled, projected features and the matching paddings.
    """
    cfg = self.config
    outputs = jnp.expand_dims(inputs, axis=-1)

    # The first convolution sees the single added channel, the second one the
    # `encoder_dim` channels produced by the first.
    for in_channels in (1, cfg.encoder_dim):
      outputs, output_paddings = Conv2dSubsampling(
          encoder_dim=cfg.encoder_dim,
          dtype=cfg.dtype,
          batch_norm_momentum=cfg.batch_norm_momentum,
          batch_norm_epsilon=cfg.batch_norm_epsilon,
          input_channels=in_channels,
          output_channels=cfg.encoder_dim)(outputs, output_paddings, train)

    # Merge the subsampled feature axis and the channel axis.
    batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
    flat_shape = (batch_size, subsampled_lengths, subsampled_dims * channels)
    outputs = jnp.reshape(outputs, flat_shape)

    projection = nn.Dense(
        cfg.encoder_dim,
        use_bias=True,
        kernel_init=nn.initializers.xavier_uniform())
    outputs = projection(outputs)

    rate = cfg.input_dropout_rate
    if rate is None:
      rate = 0.1
    outputs = nn.Dropout(rate=rate, deterministic=not train)(outputs)

    return outputs, output_paddings
class Conv2dSubsampling(nn.Module):
  """Helper module used in Subsample layer.

  1) Performs strided convolution over inputs and then applies non-linearity.
  2) Also performs strided convolution over input_paddings to return the correct
  paddings for downstream layers.

  Attributes:
    input_channels: Input channel count of the 3x3 kernel.
    output_channels: Output channel count of the 3x3 kernel and bias size.
    filter_stride: Strides of the convolution; the first entry is also used
      to subsample the paddings.
    padding: Padding mode of the input convolution.
    encoder_dim: Not read by this module.
    dtype: Not read by this module.
    batch_norm_momentum: Not read by this module.
    batch_norm_epsilon: Not read by this module.
  """
  input_channels: int = 0
  output_channels: int = 0
  filter_stride: List[int] = (2, 2)
  padding: str = 'SAME'
  encoder_dim: int = 0
  dtype: Any = jnp.float32
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001

  def setup(self):
    """Creates the 3x3 convolution kernel and the zero-initialized bias."""
    self.filter_shape = (3, 3, self.input_channels, self.output_channels)
    self.kernel = self.param('kernel',
                             nn.initializers.xavier_uniform(),
                             self.filter_shape)
    self.bias = self.param(
        'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)

  @nn.compact
  def __call__(self, inputs, paddings, train):
    """Returns the subsampled, masked activations and subsampled paddings.

    Args:
      inputs: 4-D array in NHWC layout.
      paddings: 2-D paddings; axis 1 is subsampled with stride
        `filter_stride[0]`.
      train: Unused by this module.

    Returns:
      A pair `(outputs, out_padding)`.
    """
    # Computing strided convolution to subsample inputs.
    groups = inputs.shape[3] // self.filter_shape[2]
    conv_out = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=self.kernel,
        window_strides=self.filter_stride,
        padding=self.padding,
        rhs_dilation=(1, 1),
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        feature_group_count=groups)
    # Broadcast the bias over every axis except the channel axis.
    bias_shape = (1,) * (conv_out.ndim - 1) + (-1,)
    conv_out = nn.relu(conv_out + jnp.reshape(self.bias, bias_shape))

    # Computing correct paddings post input convolution: pad the length up to
    # the next multiple of the stride, then pick every stride-th entry with a
    # size-1 convolution.
    stride = self.filter_stride[0]
    input_length = paddings.shape[1]
    pad_len = -input_length % stride
    out_padding = jax.lax.conv_general_dilated(
        lhs=paddings[:, :, None],
        rhs=jnp.ones([1, 1, 1]),
        window_strides=self.filter_stride[:1],
        padding=[(0, pad_len)],
        dimension_numbers=('NHC', 'HIO', 'NHC'))
    out_padding = jnp.squeeze(out_padding, axis=-1)

    # Mask outputs by correct paddings to ensure padded elements in inputs map
    # to padded value in outputs.
    keep = 1.0 - out_padding[:, :, None, None]
    return conv_out * keep, out_padding
class FeedForwardModule(nn.Module):
  """Feed-forward block: BatchNorm -> Dense -> ReLU -> padding mask -> Dropout.

  Attributes:
    config: Model hyperparameters; this module reads `encoder_dim`, `dtype`,
      `batch_norm_momentum`, `batch_norm_epsilon` and
      `feed_forward_dropout_rate`.
  """
  config: DeepspeechConfig

  @nn.compact
  def __call__(self, inputs, input_paddings=None, train=False):
    """Applies the feed-forward block.

    Args:
      inputs: Activations whose last axis is the feature axis.
      input_paddings: Optional paddings with one axis fewer than `inputs` and
        1 marking padded positions. None means that nothing is padded.
      train: If True, batch statistics are used in BatchNorm and dropout is
        active.

    Returns:
      The transformed activations, zeroed at padded positions.
    """
    config = self.config

    inputs = BatchNorm(config.encoder_dim,
                       config.dtype,
                       config.batch_norm_momentum,
                       config.batch_norm_epsilon)(inputs, input_paddings, train)
    inputs = nn.Dense(
        config.encoder_dim,
        use_bias=True,
        kernel_init=nn.initializers.xavier_uniform())(
            inputs)
    inputs = nn.relu(inputs)

    # `input_paddings` defaults to None; previously `1 - None` raised a
    # TypeError here even though BatchNorm accepts missing paddings.
    if input_paddings is not None:
      inputs *= jnp.expand_dims(1 - input_paddings, -1)

    if config.feed_forward_dropout_rate is None:
      feed_forward_dropout_rate = 0.1
    else:
      feed_forward_dropout_rate = config.feed_forward_dropout_rate
    inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
        inputs, deterministic=not train)
    return inputs
class LayerNorm(nn.Module):
  """Module implementing layer normalization.

  This implementation is same as in this paper:
  https://arxiv.org/pdf/1607.06450.pdf.

  note: we multiply normalized inputs by (1 + scale) and initialize scale to
  zeros, this differs from default flax implementation of multiplying by scale
  and initializing to ones.

  Attributes:
    dim: Size of the normalized (last) axis.
    epsilon: Constant added to the variance before the reciprocal square root.
  """
  dim: int = 0
  epsilon: float = 1e-6

  def setup(self):
    """Creates the zero-initialized `scale` and `bias` parameters."""
    self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
    self.bias = self.param('bias', nn.initializers.zeros, [self.dim])

  @nn.compact
  def __call__(self, inputs):
    """Normalizes `inputs` over the last axis, then scales and shifts."""
    mean = jnp.mean(inputs, axis=-1, keepdims=True)
    centered = inputs - mean
    var = jnp.mean(jnp.square(centered), axis=-1, keepdims=True)

    normalized = centered * jax.lax.rsqrt(var + self.epsilon)
    return normalized * (1 + self.scale) + self.bias
class BatchNorm(nn.Module):
  """Implements batch norm respecting input paddings.

  This implementation takes into account input padding by masking inputs before
  computing mean and variance.

  This is inspired by lingvo jax implementation of BatchNorm:
  https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92

  and the corresponding defaults for momentum and epsilon have been copied over
  from lingvo.

  Attributes:
    encoder_dim: Size of the feature (last) axis.
    dtype: dtype of the running statistics and parameters.
    batch_norm_momentum: Decay rate of the running mean and variance.
    batch_norm_epsilon: Constant added to the variance before the square root.
  """
  encoder_dim: int = 0
  dtype: Any = jnp.float32
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001

  def setup(self):
    """Creates the running statistics and the scale/bias parameters."""
    feature_dim = self.encoder_dim
    stats_dtype = self.dtype

    self.ra_mean = self.variable('batch_stats',
                                 'mean',
                                 lambda s: jnp.zeros(s, stats_dtype),
                                 feature_dim)
    self.ra_var = self.variable('batch_stats',
                                'var',
                                lambda s: jnp.ones(s, stats_dtype),
                                feature_dim)

    # The scale is stored as an offset from one (applied as `1 + gamma`), so
    # it is initialized to zeros.
    self.gamma = self.param('scale', nn.initializers.zeros, feature_dim,
                            stats_dtype)
    self.beta = self.param('bias', nn.initializers.zeros, feature_dim,
                           stats_dtype)

  def _get_default_paddings(self, inputs):
    """Gets the default paddings for an input."""
    shape = [*inputs.shape]
    shape[-1] = 1
    return jnp.zeros(shape, dtype=inputs.dtype)

  @nn.compact
  def __call__(self, inputs, input_paddings=None, train=False):
    """Normalizes `inputs` over all axes but the last.

    Args:
      inputs: Activations whose last axis is the feature axis.
      input_paddings: Optional paddings with 1 marking padded positions; None
        means that nothing is padded.
      train: If True, statistics are computed from the unpadded inputs, summed
        across the 'batch' axis with `psum`, and the running statistics are
        updated. Otherwise the running statistics are used.

    Returns:
      The normalized activations, zeroed at padded positions.
    """
    if input_paddings is None:
      padding = self._get_default_paddings(inputs)
    else:
      padding = jnp.expand_dims(input_paddings, -1)

    if train:
      reduce_axes = list(range(inputs.ndim - 1))
      decay = self.batch_norm_momentum
      mask = 1.0 - padding

      # Masked sum and count of valid positions, aggregated across devices.
      sum_v = jnp.sum(inputs * mask, axis=reduce_axes, keepdims=True)
      count_v = jnp.sum(
          jnp.ones_like(inputs) * mask, axis=reduce_axes, keepdims=True)
      sum_v = jax.lax.psum(sum_v, axis_name='batch')
      count_v = jax.lax.psum(count_v, axis_name='batch')
      # Guard against division by zero when everything is padded.
      count_v = jnp.maximum(count_v, 1.0)
      mean = sum_v / count_v

      centered = inputs - mean
      sum_vv = jnp.sum(
          centered * centered * mask, axis=reduce_axes, keepdims=True)
      sum_vv = jax.lax.psum(sum_vv, axis_name='batch')
      var = sum_vv / count_v

      self.ra_mean.value = decay * self.ra_mean.value + (1 - decay) * mean
      self.ra_var.value = decay * self.ra_var.value + (1 - decay) * var
    else:
      mean = self.ra_mean.value
      var = self.ra_var.value

    inv = (1 + self.gamma) / jnp.sqrt(var + self.batch_norm_epsilon)
    bn_output = (inputs - mean) * inv + self.beta
    return bn_output * (1.0 - padding)
@jax.vmap
def flip_sequences(inputs: Array, lengths: Array) -> Array:
  """Flips a sequence of inputs along the time dimension.

  This function can be used to prepare inputs for the reverse direction of a
  bidirectional LSTM. It solves the issue that, when naively flipping multiple
  padded sequences stored in a matrix, the first elements would be padding
  values for those sequences that were padded. This function keeps the padding
  at the end, while flipping the rest of the elements.

  Example:
  ```python
  inputs = [[1, 0, 0],
            [2, 3, 0],
            [4, 5, 6]]
  lengths = [1, 2, 3]
  flip_sequences(inputs, lengths) = [[1, 0, 0],
                                     [3, 2, 0],
                                     [6, 5, 4]]
  ```

  Args:
    inputs: An array of input IDs <int>[batch_size, seq_length].
    lengths: The length of each sequence <int>[batch_size].

  Returns:
    An ndarray with the flipped inputs.
  """
  # Under `vmap` this body sees one sequence and its scalar length.
  seq_length = inputs.shape[0]
  # Reversed positions, rotated by the sequence length so that the valid
  # prefix comes out reversed and the padding stays at the end.
  reversed_positions = jnp.arange(seq_length - 1, -1, -1)
  gather_idxs = (reversed_positions + lengths) % seq_length
  return inputs[gather_idxs]
class GenericRNNSequenceEncoder(nn.Module):
  """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`.

  The sequence can be encoded left-to-right (default) or right-to-left (by
  calling the module with reverse=True). Regardless of encoding direction,
  outputs[i, j, ...] is the representation of inputs[i, j, ...].

  Attributes:
    hidden_size: The hidden size of the RNN cell. NOTE(review): not read
      anywhere in this class; the cell is built from `cell_kwargs` only.
    cell_type: The RNN cell module to use, for example, `nn.LSTMCell`.
    cell_kwargs: Optional keyword arguments for the recurrent cell.
    recurrent_dropout_rate: The dropout to apply across time steps.
      NOTE(review): currently has no effect, because `__call__` always passes
      a `None` dropout mask and `unroll_cell` ignores it.
  """
  hidden_size: int
  cell_type: Type[nn.recurrent.RNNCellBase]
  cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict()
  recurrent_dropout_rate: float = 0.0

  def setup(self):
    """Instantiates the recurrent cell from `cell_type` and `cell_kwargs`."""
    self.cell = self.cell_type(**self.cell_kwargs)

  @functools.partial(  # Repeatedly calls the below method to encode the inputs.
      nn.transforms.scan,
      variable_broadcast='params',
      in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast),
      out_axes=1,
      split_rngs={'params': False})
  def unroll_cell(self,
                  cell_state: StateType,
                  inputs: Array,
                  recurrent_dropout_mask: Optional[Array],
                  deterministic: bool):
    """Unrolls a recurrent cell over an input sequence.

    The scan wrapper iterates over axis 1 of `inputs` (`in_axes=1`) and stacks
    the per-step results along axis 1 (`out_axes=1`); `cell_state` is the
    carry and the remaining two arguments are broadcast to every step.

    Args:
      cell_state: The initial cell state, shape: <float32>[batch_size,
        hidden_size] (or an n-tuple thereof).
      inputs: The input sequence. <float32>[batch_size, seq_len, input_dim].
      recurrent_dropout_mask: An optional recurrent dropout mask to apply in
        between time steps. <float32>[batch_size, hidden_size].
        NOTE(review): accepted but not used by the body below.
      deterministic: Disables recurrent dropout when set to True.
        NOTE(review): accepted but not used by the body below.

    Returns:
      The cell state after processing the complete sequence (including padding),
      and a tuple with all intermediate cell states and cell outputs.
    """
    # We do not directly scan the cell itself, since it only returns the output.
    # This returns both the state and the output, so we can slice out the
    # correct final states later.
    new_cell_state, output = self.cell(cell_state, inputs)
    return new_cell_state, (new_cell_state, output)

  def __call__(self,
               inputs: Array,
               input_paddings: Array,
               initial_state: StateType,
               reverse: bool = False,
               deterministic: bool = False):
    """Unrolls the RNN cell over the inputs.

    Arguments:
      inputs: A batch of sequences. Shape: <float32>[batch_size, seq_len,
        input_dim].
      input_paddings: Padding indicators for `inputs`. Sequence lengths are
        derived as sum(1 - input_paddings) over the last axis, so padded
        positions are expected to hold 1 (presumably shaped [batch_size,
        seq_len] -- confirm against callers).
      initial_state: The initial state for the RNN cell. Shape: [batch_size,
        hidden_size].
      reverse: Process the inputs in reverse order, and reverse the outputs.
        This means that the outputs still correspond to the order of the inputs,
        but their contexts come from the right, and not from the left.
      deterministic: Forwarded to `unroll_cell`, which currently ignores it.

    Returns:
      The per-step cell outputs for the whole sequence, stacked along axis 1.
      The final cell states are computed by the scan but discarded here.
    """
    # Number of non-padded steps per sequence; used only for flipping.
    lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32)
    if reverse:
      inputs = flip_sequences(inputs, lengths)

    # Recurrent dropout is not implemented here: the mask is always None.
    recurrent_dropout_mask = None
    _, (_, outputs) = self.unroll_cell(initial_state,
                                       inputs,
                                       recurrent_dropout_mask,
                                       deterministic)

    if reverse:
      # Flip back so that outputs[i, j] lines up with the original inputs[i, j].
      outputs = flip_sequences(outputs, lengths)

    return outputs
class GenericRNN(nn.Module):
  """Generic RNN class.

  This provides generic RNN functionality to encode sequences with any RNN cell.
  The class provides unidirectional and bidirectional layers, and these are
  stacked when asking for multiple layers.

  This class can be used to create a specific RNN class such as LSTM or GRU.

  Attributes:
    cell_type: An RNN cell class to use, e.g., `flax.linen.LSTMCell`.
    hidden_size: The size of each recurrent cell.
    num_layers: The number of stacked recurrent layers. The output of one
      layer feeds into the next layer.
    dropout_rate: Dropout rate intended to be applied between layers.
      NOTE(review): not read by `__call__`; no inter-layer dropout is applied.
    recurrent_dropout_rate: Dropout rate intended for the hidden state at each
      time step. It is forwarded to each `GenericRNNSequenceEncoder`.
    bidirectional: Process the sequence left-to-right and right-to-left and
      concatenate the outputs from the two directions.
    cell_kwargs: Optional keyword arguments to instantiate the cell with.
  """
  cell_type: Type[nn.recurrent.RNNCellBase]
  hidden_size: int
  num_layers: int = 1
  dropout_rate: float = 0.
  recurrent_dropout_rate: float = 0.
  bidirectional: bool = False
  cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict()

  @nn.compact
  def __call__(
      self,
      inputs: Array,
      input_paddings: Array,
      initial_states: Optional[Sequence[StateType]] = None,
      deterministic: bool = False) -> Array:
    """Processes the input sequence using the recurrent cell.

    Args:
      inputs: The input sequence <float32>[batch_size, sequence_length, ...]
      input_paddings: Padding indicators for `inputs`, forwarded to every
        sequence encoder, which derives sequence lengths from them.
      initial_states: The initial states for the cells. You must provide
        `num_layers` initial states (when using bidirectional, `num_layers *
        2`).
        These must be ordered in the following way: (layer_0_forward,
        layer_0_backward, layer_1_forward, layer_1_backward, ...). If None,
        all initial states will be initialized with zeros.
      deterministic: Forwarded to every sequence encoder.

    Returns:
      The sequence of all outputs for the final layer. For bidirectional
      layers, forward and backward outputs are concatenated on the last axis.
      Final cell states are not returned.

    Raises:
      ValueError: If the number of provided initial states does not equal
        `num_layers` (times two when bidirectional).
    """
    batch_size = inputs.shape[0]
    num_directions = 2 if self.bidirectional else 1
    num_cells = self.num_layers * num_directions

    # Construct initial states.
    if initial_states is None:  # Initialize with zeros.
      rng = jax.random.PRNGKey(0)
      initial_states = [
          self.cell_type.initialize_carry(rng, (batch_size,), self.hidden_size)
          for _ in range(num_cells)
      ]
    if len(initial_states) != num_cells:
      # `num_cells` is a local; reading it from `self` raised AttributeError
      # and hid the intended ValueError.
      raise ValueError(
          f'Please provide {num_cells} (`num_layers`, *2 if bidirectional) '
          f'initial states.')

    # For each layer, apply the forward and optionally the backward RNN cell.
    cell_idx = 0
    for _ in range(self.num_layers):
      # Unroll an RNN cell (forward direction) for this layer.
      outputs = GenericRNNSequenceEncoder(
          cell_type=self.cell_type,
          cell_kwargs=self.cell_kwargs,
          hidden_size=self.hidden_size,
          recurrent_dropout_rate=self.recurrent_dropout_rate,
          name=f'{self.name}SequenceEncoder_{cell_idx}')(
              inputs,
              input_paddings,
              initial_state=initial_states[cell_idx],
              deterministic=deterministic)
      cell_idx += 1

      # Unroll an RNN cell (backward direction) for this layer.
      if self.bidirectional:
        backward_outputs = GenericRNNSequenceEncoder(
            cell_type=self.cell_type,
            cell_kwargs=self.cell_kwargs,
            hidden_size=self.hidden_size,
            recurrent_dropout_rate=self.recurrent_dropout_rate,
            name=f'{self.name}SequenceEncoder_{cell_idx}')(
                inputs,
                input_paddings,
                initial_state=initial_states[cell_idx],
                reverse=True,
                deterministic=deterministic)
        outputs = jnp.concatenate([outputs, backward_outputs], axis=-1)
        cell_idx += 1

      # The (possibly concatenated) outputs become the next layer's inputs.
      inputs = outputs

    return outputs
class LSTM(nn.Module):
  """LSTM built by delegating to `GenericRNN`.

  Attributes:
    hidden_size: The size of each recurrent cell.
    num_layers: The number of stacked recurrent layers. The output of one
      layer feeds into the next layer.
    dropout_rate: Dropout rate intended to be applied between LSTM layers.
      Forwarded to `GenericRNN`. NOTE(review): `GenericRNN` does not apply it.
    recurrent_dropout_rate: Dropout rate intended for the hidden state at each
      time step. Forwarded to `GenericRNN`.
    bidirectional: Process the sequence left-to-right and right-to-left and
      concatenate the outputs from the two directions.
    cell_type: The LSTM cell class to use. Default:
      `flax.linen.OptimizedLSTMCell`. If you use hidden_size of >2048, consider
      using `flax.linen.LSTMCell` instead, since the optimized LSTM cell works
      best for hidden sizes up to 2048.
    cell_kwargs: Optional keyword arguments to instantiate the cell with.
  """
  hidden_size: int
  num_layers: int = 1
  dropout_rate: float = 0.
  recurrent_dropout_rate: float = 0.
  bidirectional: bool = False
  cell_type: Any = nn.OptimizedLSTMCell
  cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict()

  @nn.compact
  def __call__(
      self,
      inputs: Array,
      input_paddings: Array,
      initial_states: Optional[Sequence[StateType]] = None,
      deterministic: bool = False) -> Array:
    """Processes an input sequence with an LSTM cell.

    Example usage:
    ```
      inputs = np.random.normal(size=(2, 3, 4))
      # 1 marks padding: the sequences have lengths 1 and 3.
      input_paddings = np.array([[0, 1, 1], [0, 0, 0]])
      model = LSTM(hidden_size=10)
      variables = model.init(rng, inputs, input_paddings)
      outputs = model.apply(variables, inputs, input_paddings)
    ```

    Args:
      inputs: The input sequence <float32>[batch_size, sequence_length, ...]
      input_paddings: Padding indicators for `inputs`; forwarded unchanged to
        `GenericRNN`.
      initial_states: The initial states for the cells. You must provide
        `num_layers` initial states (when using bidirectional, `num_layers *
        2`). These must be ordered in the following way: (layer_0_forward,
        layer_0_backward, layer_1_forward, layer_1_backward, ...). If None,
        all initial states will be initialized with zeros.
      deterministic: Forwarded unchanged to `GenericRNN`.

    Returns:
      The sequence of all outputs for the final layer. Final cell states are
      not returned.
    """
    # The fixed name 'LSTM' becomes part of the parameter path and of the
    # sequence-encoder names that GenericRNN derives from `self.name`.
    rnn = GenericRNN(
        cell_type=self.cell_type,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        dropout_rate=self.dropout_rate,
        recurrent_dropout_rate=self.recurrent_dropout_rate,
        bidirectional=self.bidirectional,
        cell_kwargs=self.cell_kwargs,
        name='LSTM')
    return rnn(
        inputs,
        input_paddings,
        initial_states=initial_states,
        deterministic=deterministic)
class BatchRNN(nn.Module):
  """A single Deepspeech encoder layer: batch norm followed by an LSTM."""
  config: DeepspeechConfig

  @nn.compact
  def __call__(self, inputs, input_paddings, train):
    """Normalizes `inputs` and runs one recurrent layer over them.

    Args:
      inputs: Features fed to the layer.
      input_paddings: Paddings forwarded to both the batch norm and the LSTM.
      train: Forwarded to the batch norm to select batch vs. running stats.

    Returns:
      The output of the recurrent layer.
    """
    cfg = self.config
    normalized = BatchNorm(cfg.encoder_dim,
                           cfg.dtype,
                           cfg.batch_norm_momentum,
                           cfg.batch_norm_epsilon)(
                               normalized_inputs := inputs,
                               input_paddings,
                               train) if False else BatchNorm(
                                   cfg.encoder_dim,
                                   cfg.dtype,
                                   cfg.batch_norm_momentum,
                                   cfg.batch_norm_epsilon)(inputs,
                                                           input_paddings,
                                                           train)

    # A bidirectional layer gives each direction half of the width.
    if cfg.bidirectional:
      hidden_size = cfg.encoder_dim // 2
    else:
      hidden_size = cfg.encoder_dim

    if cfg.use_cudnn_lstm:
      rnn = CudnnLSTM(
          input_size=cfg.encoder_dim,
          hidden_size=hidden_size,
          bidirectional=cfg.bidirectional,
          num_layers=1)
    else:
      rnn = LSTM(
          hidden_size=hidden_size,
          bidirectional=cfg.bidirectional,
          num_layers=1)

    return rnn(normalized, input_paddings)
class Deepspeech(nn.Module):
  """Deepspeech speech recognition model.

  Takes audio input signals and outputs per-time-step logits over the
  vocabulary. The output is then fed into a CTC loss which eliminates the
  need for alignment with targets.
  """
  config: DeepspeechConfig

  def setup(self):
    """Builds the SpecAugment layer from the model config."""
    cfg = self.config
    self.specaug = SpecAug(
        freq_mask_count=cfg.freq_mask_count,
        freq_mask_max_bins=cfg.freq_mask_max_bins,
        time_mask_count=cfg.time_mask_count,
        time_mask_max_frames=cfg.time_mask_max_frames,
        time_mask_max_ratio=cfg.time_mask_max_ratio,
        time_masks_per_frame=cfg.time_masks_per_frame,
        use_dynamic_time_mask_max_frames=cfg.use_dynamic_time_mask_max_frames)

  @nn.compact
  def __call__(self, inputs, input_paddings, train):
    """Runs the full model.

    Args:
      inputs: Raw audio signals.
      input_paddings: Paddings for `inputs`.
      train: Enables SpecAugment (if configured) and is forwarded to the
        sub-layers.

    Returns:
      A pair (logits over `config.vocab_size`, paddings for those logits).
    """
    config = self.config
    use_residual = config.enable_residual_connections

    # Compute normalized log mel spectrograms from input audio signal.
    frontend = MelFilterbankFrontend(
        LibrispeechPreprocessingConfig(),
        per_bin_mean=LIBRISPEECH_MEAN_VECTOR,
        per_bin_stddev=LIBRISPEECH_STD_VECTOR)
    features, paddings = frontend(inputs, input_paddings)

    # Ablate random parts of input along temporal and frequency dimension
    # following the specaug procedure in https://arxiv.org/abs/1904.08779.
    if config.use_specaug and train:
      features, paddings = self.specaug(features, paddings)

    # Subsample input by a factor of 4 by performing strided convolutions.
    features, paddings = Subsample(config=config)(features, paddings, train)

    # Run the lstm layers.
    for _ in range(config.num_lstm_layers):
      layer_out = BatchRNN(config)(features, paddings, train)
      features = (features + layer_out) if use_residual else layer_out

    # Run the feed-forward layers.
    for _ in range(config.num_ffn_layers):
      layer_out = FeedForwardModule(config=self.config)(
          features, paddings, train)
      features = (features + layer_out) if use_residual else layer_out

    # Run the decoder which in this case is a trivial projection layer.
    if config.enable_decoder_layer_norm:
      features = LayerNorm(config.encoder_dim)(features)

    logits = nn.Dense(
        config.vocab_size,
        use_bias=True,
        kernel_init=nn.initializers.xavier_uniform())(features)

    return logits, paddings
# Total batch size of the dummy batch; load_dummy_batch() splits it across
# the local devices.
BATCH_SIZE = 128
# Passed to DeepspeechConfig(use_cudnn_lstm=...); selects CudnnLSTM instead of
# the LSTM module inside BatchRNN.
USE_CUDNN_LSTM=False
#@title Pmapped Train Loop 1 step
import jax
import numpy as np
import functools
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax import jax_utils
import optax
from absl import logging
import jax.lax as lax
import time
from absl import app
from absl import flags
from absl import logging
import os
# Added to the gradient norm in train_step() so the clip scaling factor never
# divides by zero.
_GRAD_CLIP_EPS = 1e-6
def shard(batch, n_devices=None):
  """Prepares the batch for pmap by adding a leading n_devices dimension.

  If all the entries are lists, assume they are already divided into n_devices
  smaller arrays and stack them for pmapping. If all the entries are arrays,
  assume they have leading dimension divisible by n_devices and reshape.

  Args:
    batch: A dict of arrays or lists of arrays
    n_devices: If None, this will be set to jax.local_device_count().

  Returns:
    Sharded data.
  """
  if n_devices is None:
    n_devices = jax.local_device_count()

  # TODO(mbadura): Specify a sharding function per dataset instead
  # If entries in the batch dict are lists, then the data is already divided
  # into n_devices chunks, so we need to stack them.
  if all(isinstance(v, list) for v in batch.values()):
    assert all(len(v) == n_devices for v in batch.values())
    # transpose a dict of lists to a list of dicts
    shards = [{k: v[i] for (k, v) in batch.items()} for i in range(n_devices)]
    # `jax.tree_map` was deprecated and later removed from JAX;
    # `jax.tree_util.tree_map` is the stable spelling of the same function.
    return jax.tree_util.tree_map(
        lambda *vals: np.stack(vals, axis=0), shards[0], *shards[1:])

  # Otherwise, the entries are arrays, so just reshape them.
  def _shard_array(array):
    # (batch, ...) -> (n_devices, batch // n_devices, ...)
    return array.reshape((n_devices, -1) + array.shape[1:])

  return jax.tree_util.tree_map(_shard_array, batch)
def load_dummy_batch():
  """Builds an all-zero dummy batch and shards it across local devices.

  The batch holds BATCH_SIZE examples with 320000-sample inputs and
  256-element targets, each paired with an all-zero padding array of the same
  shape. The shapes of the sharded inputs are printed.

  Returns:
    A dict with 'inputs' and 'targets' entries, each an (array, paddings)
    pair carrying a leading device axis added by `shard`.
  """
  audio_shape = (BATCH_SIZE, 320000)
  target_shape = (BATCH_SIZE, 256)

  def _zeros_pair(shape):
    # Two independent zero arrays: the values and their paddings.
    return jnp.array(np.zeros(shape)), jnp.array(np.zeros(shape))

  sharded_padded_batch = shard({
      'inputs': _zeros_pair(audio_shape),
      'targets': _zeros_pair(target_shape),
  })

  sharded_inputs, sharded_input_paddings = sharded_padded_batch['inputs']
  print(sharded_inputs.shape, sharded_input_paddings.shape)

  return sharded_padded_batch
# Initing optimizer and LR schedule
def jax_cosine_warmup(peak_lr=0.02, warmup_steps=5000, total_steps=60000):
  """Creates a linear-warmup followed by cosine-decay learning rate schedule.

  Args:
    peak_lr: Learning rate reached at the end of warm-up, and the starting
      value of the cosine decay.
    warmup_steps: Number of steps of linear warm-up from 0 to `peak_lr`.
    total_steps: Total number of training steps; the cosine decay spans the
      `total_steps - warmup_steps` steps after warm-up (at least 1).

  Returns:
    An optax schedule function mapping a step count to a learning rate.
  """
  warmup_fn = optax.linear_schedule(
      init_value=0.,
      end_value=peak_lr,
      transition_steps=warmup_steps)
  cosine_steps = max(total_steps - warmup_steps, 1)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=peak_lr, decay_steps=cosine_steps)
  # Switch to the cosine phase exactly when warm-up ends. The boundary was
  # hard-coded to 500 while warm-up lasted 5000 steps, so the learning rate
  # jumped from a tenth of the peak straight to the peak at step 500.
  schedule_fn = optax.join_schedules(
      schedules=[warmup_fn, cosine_fn],
      boundaries=[warmup_steps])
  return schedule_fn
def init_optimizer_state(params):
  """Creates an AdamW optimizer driven by the warmup/cosine schedule.

  Args:
    params: The (unreplicated) model parameters to initialize the state for.

  Returns:
    A pair (optimizer state replicated across local devices, the optimizer's
    update function).
  """
  optimizer = optax.adamw(
      learning_rate=jax_cosine_warmup(),
      b1=0.98,
      b2=0.99,
      eps=1e-8,
      weight_decay=0.0)
  replicated_state = jax_utils.replicate(optimizer.init(params))
  return replicated_state, optimizer.update
def train_step(model_class,
               opt_update_fn,
               params,
               batch_stats,
               optimizer_state,
               batch,
               rng,
               grad_clip):
  """Runs one training step; meant to be pmapped with axis_name='batch'.

  Args:
    model_class: The Flax module instance whose `apply` computes the logits.
    opt_update_fn: Optimizer update function (as returned by
      init_optimizer_state).
    params: Model parameters.
    batch_stats: The 'batch_stats' variable collection of the model.
    optimizer_state: Current optimizer state.
    batch: Dict with 'inputs' and 'targets', each an (array, paddings) pair.
    rng: PRNG key used for the 'dropout' rng stream.
    grad_clip: Maximum global gradient norm, or None to disable clipping.

  Returns:
    A tuple (updated params, updated batch stats, updated optimizer state,
    loss, gradient norm).
  """

  def _loss_fn(params):
    """Loss function used for training."""
    inputs, input_paddings = batch['inputs']
    targets, target_paddings = batch['targets']
    (logits, logit_paddings), updated_vars = model_class.apply(
        {'params': params, 'batch_stats': batch_stats},
        inputs,
        input_paddings,
        train=True,
        rngs={'dropout' : rng},
        mutable=['batch_stats'])
    new_batch_stats = updated_vars['batch_stats']

    logprobs = nn.log_softmax(logits)
    per_seq_loss = optax.ctc_loss(logprobs,
                                  logit_paddings,
                                  targets,
                                  target_paddings)
    # Normalize by the number of non-padded target tokens on this device.
    normalizer = jnp.sum(1 - target_paddings)
    normalized_loss = jnp.sum(per_seq_loss) / jnp.maximum(normalizer, 1)
    return normalized_loss, new_batch_stats

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
  (loss, new_batch_stats), grad = grad_fn(params)
  # Average loss and gradients across devices.
  (loss, grad) = lax.pmean((loss, grad), axis_name='batch')

  grad_norm = jnp.sqrt(
      sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

  if grad_clip is not None:
    # Scale gradients down so their global norm is at most `grad_clip`.
    grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
    grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
    # `jax.tree_map` was deprecated and later removed from JAX; use the
    # stable `jax.tree_util.tree_map`, matching `tree_leaves` above.
    grad = jax.tree_util.tree_map(lambda x: x * grad_scaling_factor, grad)

  updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
                                               params)
  updated_params = optax.apply_updates(params, updates)

  return updated_params, new_batch_stats, new_optimizer_state, jnp.mean(loss), jnp.mean(grad_norm)
def main():
  """Times a few pmapped Deepspeech training steps on a dummy batch.

  Initializes the model and optimizer, replicates their state across local
  devices, then runs `num_training_steps` steps. Step 0 (which includes
  compilation) is excluded from both the timing and the profiler trace,
  which start at step 1.
  """
  sharded_padded_batch = load_dummy_batch()

  # Initing model
  config = DeepspeechConfig(use_cudnn_lstm=USE_CUDNN_LSTM)
  model_class = Deepspeech(config)
  rng = jax.random.PRNGKey(0)
  params_rng, dropout_rng = jax.random.split(rng, 2)

  model_init_fn = jax.jit(functools.partial(model_class.init, train=False))
  input_shape = [(320000,), (320000,)]
  fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]

  print('Initializing model.')
  variables = model_init_fn(
      {'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch)
  # Index the collections directly. `FrozenDict.pop('params')` returns
  # (remaining_collections, params), so the first element was
  # {'batch_stats': ...} rather than the batch stats themselves, and
  # train_step then nested it under a second 'batch_stats' key.
  params = variables['params']
  batch_stats = variables['batch_stats']

  print('Initializing optimizer')
  replicated_optimizer_state, opt_update_fn = init_optimizer_state(params)
  replicated_params = jax_utils.replicate(params)
  replicated_batch_stats = jax_utils.replicate(batch_stats)

  # Starting Training to measure time:
  num_training_steps = 10
  grad_clip=5.0

  # Defining pmapped update step
  bound_train_step = functools.partial(train_step, model_class, opt_update_fn)
  # params, batch stats, optimizer state and batch are split per device;
  # rng and grad_clip are broadcast to every device.
  pmapped_train_step = jax.pmap(bound_train_step,
                                axis_name='batch',
                                in_axes=(0, 0, 0, 0, None, None))

  print('Starting training')
  print('JAX local device count = ', jax.local_device_count())
  for step in range(num_training_steps):
    if step == 1:
      # Skip step 0 so compilation is not part of the measurement.
      start_time = time.time()
      jax.profiler.start_trace("/experiment_runs/traces/old_layer_bs128_jax044_10steps_new/", create_perfetto_trace=True)

    (
        replicated_params,
        replicated_batch_stats,
        replicated_optimizer_state,
        loss,
        grad_norm) = pmapped_train_step(
            replicated_params,
            replicated_batch_stats,
            replicated_optimizer_state,
            sharded_padded_batch,
            rng,
            grad_clip)

    # Reading loss[0] also forces the step to finish before the next one.
    print('{}) loss = {} grad_norm = {}'.format(step, loss[0], grad_norm[0]))

  jax.profiler.stop_trace()
  end_time = time.time()
  print('JAX program execution took %s seconds' % (end_time - start_time))
# Run the benchmark as soon as this cell/script is executed.
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment