Lhotse DDP test
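A minimal demo of driving PyTorch DistributedDataParallel (DDP) training with Lhotse's dataloading components: one process is spawned per GPU, each builds a rank-aware DynamicBucketingSampler over LibriSpeech dev-clean-2, computes Fbank features on the fly, and trains torchaudio's DeepSpeech with CTC loss.
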
import os

import torch
import torchaudio
from torch.utils.data import DataLoader
from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import SpeechSynthesisDataset, DynamicBucketingSampler, OnTheFlyFeatures, make_worker_init_fn
from lhotse.recipes import download_librispeech, prepare_librispeech
from tqdm import tqdm

# Download and set up data on first run
if not os.path.exists("LibriSpeech/dev-clean-2"):
    download_librispeech(dataset_parts="dev-clean-2")
    libri = prepare_librispeech(corpus_dir="LibriSpeech")
    CutSet.from_manifests(**libri['dev-clean-2']).to_jsonl("LibriSpeech/dev-clean-2.jsonl")

import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        loss: torch.nn.Module
    ) -> None:
        self.gpu_id = gpu_id
        self.train_data = train_data
        self.optimizer = optimizer
        self.loss = loss
        # Move the model onto this process's GPU and wrap it in DDP
        self.model = DDP(model.to(gpu_id), device_ids=[gpu_id])

    def train(self, max_iters: int):
        epoch = 0
        iterator = iter(self.train_data)
        for global_step in tqdm(range(max_iters), disable=self.gpu_id != 0):
            try:
                worker_info, cut_ids, batch = next(iterator)
            except StopIteration:
                # Sampler exhausted: bump the epoch so shuffling differs, then restart
                epoch += 1
                self.train_data.sampler.set_epoch(epoch)
                iterator = iter(self.train_data)
                worker_info, cut_ids, batch = next(iterator)
            # Uncomment to display cut_ids to make sure GPUs are getting different data across epochs/GPUs/processes/workers
            # print(f"GPU: {self.gpu_id}; Worker: {worker_info.id + 1}/{worker_info.num_workers}; Iteration: {global_step}, Cuts: {cut_ids}")
            log_probs = self.model(batch['features'].to(self.gpu_id))
            # CTCLoss expects (time, batch, class) log-probs, hence the transpose
            loss = self.loss(
                log_probs.transpose(0, 1),
                batch["tokens"].to(self.gpu_id),
                batch["features_lens"].to(self.gpu_id),
                batch["tokens_lens"].to(self.gpu_id)
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

class ASRDataset(SpeechSynthesisDataset):
    # Hijacking SpeechSynthesisDataset (has token collator, etc.) instead of K2SpeechRecognitionDataset for demo
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, cuts: CutSet):
        worker_info = torch.utils.data.get_worker_info()
        cut_ids = ", ".join([c.id for c in cuts])
        batch = super().__getitem__(cuts)
        return worker_info, cut_ids, batch

def main(rank: int, world_size: int, max_iters: int):
    ddp_setup(rank, world_size)
    cuts = CutSet.from_jsonl_lazy("LibriSpeech/dev-clean-2.jsonl")
    sampler = DynamicBucketingSampler(
        cuts,
        shuffle=True,
        max_duration=60,
        drop_last=True,
        num_buckets=10,
        rank=rank,
        world_size=world_size
    )
    dataset = ASRDataset(cuts, feature_input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))))
    train_data = DataLoader(
        dataset,
        sampler=sampler,
        # batch_size=None because the Lhotse sampler already yields whole batches (CutSets)
        batch_size=None,
        num_workers=8,
        worker_init_fn=make_worker_init_fn(
            rank=rank,
            world_size=world_size
        ),
        persistent_workers=True,
        pin_memory=True
    )
    tokenizer = dataset.token_collater
    model = torchaudio.models.DeepSpeech(n_feature=80, n_class=len(list(tokenizer.idx2token)))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss = torch.nn.CTCLoss(blank=list(tokenizer.idx2token).index('<pad>'), reduction="mean", zero_infinity=True)
    trainer = Trainer(model, train_data, optimizer, rank, loss)
    trainer.train(max_iters)
    destroy_process_group()

if __name__ == "__main__":
    # import argparse
    # parser = argparse.ArgumentParser(description='simple distributed training job')
    # parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    # parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    # parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    # args = parser.parse_args()
    world_size = torch.cuda.device_count()
    # Spawn one training process per GPU; each receives its rank as the first argument
    # mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
    mp.spawn(main, args=(world_size, 100), nprocs=world_size)
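
For a quick sanity check before launching a multi-GPU run, the same sampler/dataset pipeline can be exercised in a single process without DDP. The following is a minimal sketch, not part of the original gist: it assumes the definitions above are in scope, and relies on Lhotse samplers defaulting to rank=0, world_size=1 when no process group is initialized; the shape check at the end is purely illustrative.

# Hypothetical single-process smoke test (not in the original gist)
cuts = CutSet.from_jsonl_lazy("LibriSpeech/dev-clean-2.jsonl")
# Without init_process_group, the sampler falls back to rank=0, world_size=1
sampler = DynamicBucketingSampler(cuts, shuffle=True, max_duration=60, num_buckets=10)
dataset = ASRDataset(cuts, feature_input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))))
loader = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=0)
# worker_info is None here because num_workers=0 runs in the main process
worker_info, cut_ids, batch = next(iter(loader))
print(batch["features"].shape, batch["tokens"].shape)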