"""Convert NeMo CitriNet to iOS."""
import math

import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional as F
from torch.quantization import quantize_dynamic
from torch.utils.mobile_optimizer import optimize_for_mobile

# import librosa
from nemo.collections.asr.models import EncDecCTCModelBPE
# from nemo.collections.asr.parts.preprocessing import FilterbankFeatures
from nemo.utils import logging

@torch.jit.script
def normalize_batch(x: torch.Tensor, seq_len: torch.Tensor, normalize_type: str):
    eps = 1e-5
    if normalize_type == "per_feature":
        x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
        x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            if x[i, :, : seq_len[i]].shape[1] == 1:
                raise ValueError(
                    "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
                    "in torch.std() returning nan"
                )
            x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
            x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
        # make sure x_std is not zero
        x_std += eps
        return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
    elif normalize_type == "all_features":
        x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            x_mean[i] = x[i, :, : seq_len[i].item()].mean()
            x_std[i] = x[i, :, : seq_len[i].item()].std()
        # make sure x_std is not zero
        x_std += eps
        return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
    # elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
    #     x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
    #     x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
    #     return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)
    else:
        return x
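
# A minimal usage sketch (illustrative, not part of the original gist):
#   feats = torch.randn(2, 64, 120)        # (batch, features, frames)
#   lens = torch.tensor([120, 80])         # valid frames per batch element
#   normed = normalize_batch(feats, lens, "per_feature")  # same shape as feats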

@torch.jit.script
def splice_frames(x: torch.Tensor, frame_splicing: int):
    """Stacks frames together across the feature dim.

    input is (batch_size, feature_dim, num_frames)
    output is (batch_size, feature_dim * frame_splicing, num_frames)
    """
    seq = [x]
    for n in range(1, frame_splicing):
        seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
    return torch.cat(seq, dim=1)
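
# Shape sketch (illustrative): splice_frames(torch.randn(1, 64, 100), 3)
# returns a (1, 192, 100) tensor: frame_splicing copies stacked on dim 1.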

class FilterbankFeatures(nn.Module):
    """Featurizer that converts wavs to Mel Spectrograms.

    See AudioToMelSpectrogramPreprocessor for args.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2 ** -24,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
        constant=1e-5,
    ):
        super().__init__()
        if exact_pad and n_window_stride % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0, then "
                "the returned spectrogram would not be of length audio_length // hop_size. Please use an even "
                "hop_size."
            )
        self.log_zero_guard_value = log_zero_guard_value
        if (
            n_window_size is None
            or n_window_stride is None
            or not isinstance(n_window_size, int)
            or not isinstance(n_window_stride, int)
            or n_window_size <= 0
            or n_window_stride <= 0
        ):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints."
            )
        logging.info(f"PADDING: {pad_to}")
        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
        self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
        if exact_pad:
            logging.info("STFT using exact pad")
        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)
        self.exact_pad = exact_pad
        self._constant = constant
        self.normalize = normalize
        self.log = log
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2
        # filterbanks = torch.tensor(
        #     librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
        #     dtype=torch.float,
        # ).unsqueeze(0)
        # Note: torchaudio's melscale_fbanks defaults to the HTK mel scale with no
        # normalization, while librosa.filters.mel defaults to the Slaney scale
        # with Slaney norm, so the two filterbanks are not numerically identical.
        filterbanks = F.melscale_fbanks(
            sample_rate=sample_rate, n_freqs=int(self.n_fft // 2 + 1), n_mels=nfilt, f_min=lowfreq, f_max=highfreq,
        ).T.unsqueeze(0)
        self.register_buffer("fb", filterbanks)
        # Calculate maximum sequence length
        max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power
        # We want to avoid taking the log of zero.
        # There are two options: either adding or clamping to a small value.
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'."
            )
        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
        self.log_zero_guard_type = log_zero_guard_type
        logging.debug(f"sr: {sample_rate}")
        logging.debug(f"n_fft: {self.n_fft}")
        logging.debug(f"win_length: {self.win_length}")
        logging.debug(f"hop_length: {self.hop_length}")
        logging.debug(f"n_mels: {nfilt}")
        logging.debug(f"fmin: {lowfreq}")
        logging.debug(f"fmax: {highfreq}")
        logging.debug(f"using grads: {use_grads}")

    def log_zero_guard_value_fn(self, x):
        if isinstance(self.log_zero_guard_value, str):
            if self.log_zero_guard_value == "tiny":
                return torch.finfo(x.dtype).tiny
            elif self.log_zero_guard_value == "eps":
                return torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_value} for the "
                    f"log_zero_guard_type parameter. It must be either a "
                    f"number, 'tiny', or 'eps'"
                )
        else:
            return self.log_zero_guard_value
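
    # For reference (illustrative values, float32 inputs):
    # torch.finfo(torch.float32).tiny ~= 1.18e-38
    # torch.finfo(torch.float32).eps  ~= 1.19e-07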

    def get_seq_len(self, seq_len):
        # Assuming that center is True when stft_pad_amount is None
        pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
        seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
        return seq_len.to(dtype=torch.long)
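
    # Worked example (illustrative): when stft_pad_amount is None (center
    # padding), pad_amount == (n_fft // 2) * 2 == n_fft, so the formula
    # reduces to floor(seq_len / hop_length) + 1. A 1 s clip at 16 kHz with
    # hop_length=160 therefore yields floor(16000 / 160) + 1 = 101 frames,
    # matching torch.stft with center=True.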

    @property
    def filter_banks(self):
        return self.fb

    def normalize_batch(self, x: torch.Tensor, seq_len: torch.Tensor, normalize_type: str):
        if normalize_type == "per_feature":
            x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
            x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
            for i in range(x.shape[0]):
                if x[i, :, : seq_len[i]].shape[1] == 1:
                    raise ValueError(
                        "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
                        "in torch.std() returning nan"
                    )
                x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
                x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
            # make sure x_std is not zero
            # x_std += CONSTANT
            x_std += self._constant
            return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
        elif normalize_type == "all_features":
            x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
            x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
            for i in range(x.shape[0]):
                x_mean[i] = x[i, :, : seq_len[i].item()].mean()
                x_std[i] = x[i, :, : seq_len[i].item()].std()
            # make sure x_std is not zero
            # x_std += CONSTANT
            x_std += self._constant
            return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
        # elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
        #     x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
        #     x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
        #     return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)
        else:
            return x

    def splice_frames(self, x: torch.Tensor, frame_splicing: int):
        """Stacks frames together across the feature dim.

        input is (batch_size, feature_dim, num_frames)
        output is (batch_size, feature_dim * frame_splicing, num_frames)
        """
        seq = [x]
        for n in range(1, frame_splicing):
            seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
        return torch.cat(seq, dim=1)

    def forward(self, x, seq_len):
        seq_len = self.get_seq_len(seq_len.float())
        if self.stft_pad_amount is not None:
            x = torch.nn.functional.pad(
                x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
            ).squeeze(1)
        # do preemphasis
        if self.preemph is not None:
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
        # disable autocast to get full range of stft values
        # with torch.cuda.amp.autocast(enabled=False):
        #     x = self.stft(x)
        ## MSIS: the above autocast messes up the mobile export.
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False if self.exact_pad else True,
            window=self.window.to(dtype=torch.float),
            return_complex=False,
        )
        # torch returns real, imag; so convert to magnitude
        # guard is needed for sqrt if grads are passed through
        # guard = 0 if not self.use_grads else CONSTANT
        if x.dtype in [torch.cfloat, torch.cdouble]:
            x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1))
        # get power spectrum
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)
        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x)
        # log features if required
        if self.log:
            if self.log_zero_guard_type == "add":
                x = torch.log(x + self.log_zero_guard_value_fn(x))
            elif self.log_zero_guard_type == "clamp":
                x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
            # else:
            #     raise ValueError("log_zero_guard_type was not understood")
        # frame splicing if required
        if self.frame_splicing > 1:
            # x = splice_frames(x, self.frame_splicing)
            x = self.splice_frames(x, self.frame_splicing)
        # normalize if required
        if self.normalize:
            # x = normalize_batch(x, seq_len, normalize_type=self.normalize)
            x = self.normalize_batch(x, seq_len, normalize_type=self.normalize)
        # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
        max_len = x.size(-1)
        mask = torch.arange(max_len).to(x.device)
        mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
        del mask
        pad_to = self.pad_to
        # if pad_to == "max":
        #     x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
        # elif pad_to > 0:
        #     pad_amt = x.size(-1) % pad_to
        #     if pad_amt != 0:
        #         x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
        if pad_to > 0:
            pad_amt = x.size(-1) % pad_to
            if pad_amt != 0:
                # x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
                x = torch.nn.functional.pad(x, (0, pad_to - pad_amt), value=float(self.pad_value))
        return x, seq_len

class ModelWrapper2(torch.nn.Module):
    """Minimal wrapper (unused below); exercises masked_fill_ in isolation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # masked_fill_ requires a boolean mask; compare against zero to get one
        mask = torch.randn((1, 80, 100)) > 0
        x.masked_fill_(mask, 0.0)
        return x

class ModelWrapper(torch.nn.Module):
    def __init__(self, exported_model):
        super().__init__()
        self.encoder = exported_model
        self.sample_rate = 16000
        self.featurizer = FilterbankFeatures(
            sample_rate=self.sample_rate,
            nfilt=80,
            n_fft=512,
            pad_to=60,
            normalize='per_feature',
            n_window_size=400,
            n_window_stride=160,
            window='hann',
            frame_splicing=1,
        )
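        # Note (illustrative): at 16 kHz, n_window_size=400 is a 25 ms analysis
        # window and n_window_stride=160 is a 10 ms hop, the usual ASR
        # front-end setup; nfilt=80 matches CitriNet's 80-dim mel input.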

    def forward(self, waveform: torch.Tensor, sample_rate: int):
        # downmix to mono, then resample to the model's expected rate
        if waveform.size(0) != 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
        length = torch.tensor([waveform.shape[1]])
        waveform, length = self.featurizer(waveform, length)
        hypothesis = self.encoder(waveform, length)
        return hypothesis

if __name__ == "__main__":
    model = EncDecCTCModelBPE.from_pretrained('stt_en_citrinet_256', map_location='cpu')
    model = model.eval()
    model.export(f"/tmp/{model._get_name()}.ts", check_trace=True)
    scripted_encoder = torch.jit.load(f"/tmp/{model._get_name()}.ts")
    wrapped_model = ModelWrapper(scripted_encoder)
    scripted_model = torch.jit.script(wrapped_model)
    quantized_model = quantize_dynamic(
        scripted_model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
    )
    optimized_model = optimize_for_mobile(quantized_model, backend="metal")
    optimized_model._save_for_lite_interpreter("/tmp/mymodel-metal.ts")
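
    # Optional smoke test (a sketch; "sample.wav" is a stand-in path, not part
    # of the original gist). Runs the scripted wrapper on a real file before
    # shipping the mobile artifact:
    #   waveform, sr = torchaudio.load("sample.wav")   # (channels, samples)
    #   with torch.no_grad():
    #       print(scripted_model(waveform, sr))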