# Script for tracing the TDNN (x-vector) speaker-embedding model from
# SpeechBrain with torch.jit.trace.
#
# Source: GitHub Gist Fhrozen/5cc0366fb40fc08f7358d5287ed69435
# Author: @Fhrozen, created January 18, 2024 02:33
import os
import sys

import numpy as np
import torch
import torchaudio
from matplotlib import pyplot as plt
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.Xvector import Xvector
from speechbrain.processing.features import InputNormalization
from torch import nn
class Extractor(nn.Module):
    """Standalone rebuild of SpeechBrain's x-vector embedding pipeline.

    The feature front end (``Fbank`` + sentence-level mean normalization),
    the ``Xvector`` TDNN and the global embedding normalizer are plain
    ``nn.Module`` members, so the whole pipeline can be passed to
    ``torch.jit.trace``. Weights are read from ``<model_path>/<name>.ckpt``
    for each name in ``model_dict``.

    Args:
        model_path: Directory holding the ``*.ckpt`` files (e.g. the
            ``savedir`` populated by ``EncoderClassifier.from_hparams``).
            Missing files are skipped (reported on stdout) and the
            corresponding module keeps its freshly initialized state.
        n_mels: Number of mel filterbank channels; also the TDNN input width.
        device: Device the sub-modules are moved to and inputs are copied to.
    """

    # Attribute names of the sub-modules that may be restored from
    # ``<model_path>/<name>.ckpt``.
    model_dict = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]

    def __init__(self, model_path, n_mels=24, device="cpu"):
        super().__init__()
        self.device = device
        self.compute_features = Fbank(n_mels=n_mels)
        self.mean_var_norm = InputNormalization(norm_type="sentence", std_norm=False)
        # The architecture must match the pretrained checkpoint exactly:
        # load_state_dict (strict by default) rejects missing/unexpected keys
        # and mismatched shapes.
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        self.mean_var_norm_emb = InputNormalization(norm_type="global", std_norm=False)
        for mod_name in self.model_dict:
            self._load_checkpoint(mod_name, model_path)
            getattr(self, mod_name).to(self.device)
        # Inference-only wrapper (forward runs under torch.no_grad): start in
        # eval mode so layers with train/eval-dependent behaviour (e.g.
        # normalization statistics) run in inference mode by default.
        # Call .train() explicitly if training-mode behaviour is wanted.
        self.eval()

    def _load_checkpoint(self, mod_name, model_path):
        """Restore ``self.<mod_name>`` from ``<model_path>/<mod_name>.ckpt`` if it exists."""
        filename = os.path.join(model_path, f"{mod_name}.ckpt")
        module = getattr(self, mod_name)
        if not os.path.exists(filename):
            # Not every module has a checkpoint (e.g. a feature extractor with
            # no learned state); report it so a missing weight file is visible.
            print(f"Skip (no checkpoint): {filename}")
            return
        if hasattr(module, "_load"):
            # Modules exposing a SpeechBrain-style ``_load`` hook (presumably
            # the InputNormalization layers) restore themselves.
            print(f"Load: {filename}")
            module._load(filename)
        else:
            print(f"Load State Dict: {filename}")
            # map_location="cpu" allows a GPU-saved checkpoint to be read on a
            # CPU-only host; the caller moves the module to self.device.
            module.load_state_dict(torch.load(filename, map_location="cpu"))

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        """Compute speaker embeddings for a batch of waveforms.

        Args:
            wavs: Waveform tensor, 1-D (single signal) or batched (presumably
                ``(batch, time)`` - TODO confirm). 1-D input is promoted to a
                batch of one.
            wav_lens: Per-item lengths, presumably relative in (0, 1] as in
                SpeechBrain; defaults to all ones (full length).
            normalize: If True, apply the global embedding normalizer
                ``mean_var_norm_emb`` to the output.

        Returns:
            The tensor produced by ``embedding_model`` (optionally
            normalized), located on ``self.device``.
        """
        # Manage single waveforms in input.
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)
        # Assign full length if wav_lens is not given.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()
        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.mean_var_norm_emb(
                embeddings, torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings
# Script: compare SpeechBrain's EncoderClassifier with the standalone
# Extractor, then trace Extractor with TorchScript and check the traced and
# reloaded models against both.
MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"
# Test utterance; an alternative audio path may be given as the first CLI arg.
WAV_PATH = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "/export/corpus01/LibriSpeech/dev-clean/1272/128104/1272-128104-0000.flac"
)
# Extractor builds Fbank with its default sample rate, assumed to be 16 kHz
# (SpeechBrain's default) - confirm if the front end changes.
TARGET_SAMPLE_RATE = 16000

classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    savedir=MODEL_PATH,
)
signal, fs = torchaudio.load(WAV_PATH)
if fs != TARGET_SAMPLE_RATE:
    signal = torchaudio.functional.resample(signal, fs, TARGET_SAMPLE_RATE)
embeddings_class = classifier.encode_batch(signal).cpu().squeeze()

# Fall back to CPU so the script also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = Extractor(MODEL_PATH, device=device)
extractor.requires_grad_(False)
extractor.eval()
embeddings_x = extractor(signal).cpu().squeeze()

# Similarity evaluation: a cosine close to 1 and a small absolute difference
# indicate the re-implementation matches the reference pipeline.
cos = nn.CosineSimilarity(dim=0, eps=1e-6)
output = cos(embeddings_x, embeddings_class)
diff = embeddings_class - embeddings_x
print(output, diff.abs().sum())

# Tracing.
traced_path = f"model_{device}.pt"
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, traced_path)
embeddings_t = traced_model(signal).squeeze()
output1 = cos(embeddings_class.to(device), embeddings_t)
output2 = cos(embeddings_x.to(device), embeddings_t)
print(embeddings_t.shape, output1, output2)

# Reload the serialized model to confirm it works on its own.
model = torch.jit.load(traced_path)
emb_m = model(signal).squeeze()
print(model.code)
print(cos(embeddings_x.to(device), emb_m))
print(emb_m)
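# NOTE: tracing also requires a small patch to the installed speechbrain
# package, in speechbrain/nnet/pooling.py (around line 296):
#   original: actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
#   patched:  actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
# Keeping `actual_size` a tensor lets torch.jit.trace record the computation.
# Converting it to a Python int makes the tracer treat the value as a
# constant taken from the example input.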
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment