@hotbaby
Last active February 27, 2024 06:55
Whisper multi-GPU distributed inference
# encoding: utf8

import json
import torch
import argparse
import whisper
from whisper import Whisper, ModelDimensions
from torch.utils.data import Dataset, DataLoader
from lightning import Trainer
from lightning import LightningModule
from lightning import Callback
from typing import Dict
from filelock import FileLock

parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--result_path", type=str, required=True)
parser.add_argument("--device_num", type=int, default=1, help="number of devices to run inference on concurrently (default: 1)")
args = parser.parse_args()


class AishellDataset(Dataset):
    """Dataset over a JSONL file where each line is one utterance record."""

    def __init__(self, data_path) -> None:
        self.data_list = []
        with open(data_path) as f:
            for line in f:
                self.data_list.append(json.loads(line))

    def __getitem__(self, idx):
        return self.data_list[idx]

    def __len__(self):
        return len(self.data_list)


class ResultCallback(Callback):
    """Appends each prediction to the result file as one JSON line."""

    def __init__(self, result_path) -> None:
        super().__init__()
        self.result_path = result_path

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Dict,
        batch: Dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Use a file lock so that concurrent writes from multiple ranks do not interleave.
        with FileLock("/tmp/whisper.lock"):
            with open(self.result_path, "a+") as f:
                f.write(json.dumps(outputs, ensure_ascii=False) + "\n")


class WhisperWrapModel(LightningModule):
    def __init__(self, model_path):
        super().__init__()
        checkpoint = torch.load(model_path)
        dims = ModelDimensions(**checkpoint["dims"])
        model = Whisper(dims)
        model.load_state_dict(checkpoint["model_state_dict"])
        # DDP communication does not support sparse tensors, so replace the
        # alignment_heads buffer with a dense tensor; otherwise DDP raises
        # "No support for sparse tensors".
        alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
        model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
        self.model: whisper.Whisper = model

    def predict_step(self, batch, batch_idx):
        wav_path = batch["wav"][0]
        transcribe = self.model.transcribe(wav_path)
        result = {
            "key": batch["key"][0],
            "wav": batch["wav"][0],
            "ref": batch["txt"][0],
            "hyp": transcribe["text"],
        }
        return result


if __name__ == "__main__":
    # data_path = "/home/rd/wenet/examples/aishell/s0/data/test/data.list"
    dataset = AishellDataset(data_path=args.data_path)
    # Slicing keeps only the first 10 records (quick test); batch_size must be 1.
    dataloader = DataLoader(dataset=dataset[:10], batch_size=1)

    # result_path = "result.jsonl"
    # Truncate any results left over from a previous run.
    with open(args.result_path, "a+") as f:
        f.truncate(0)

    # model_path = "/data/models/whisper/pytorch_model/large-v3.pt"
    lightning_model = WhisperWrapModel(args.model_path)
    trainer = Trainer(accelerator="cuda", devices=args.device_num, callbacks=[ResultCallback(args.result_path)])
    trainer.predict(model=lightning_model, dataloaders=dataloader)
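
For reference, `--data_path` is expected to point at a JSONL file (the wenet aishell `data.list` format) in which each line carries the `key`, `wav`, and `txt` fields read in `predict_step`. A minimal sketch of producing such a file and launching the script; the file paths and the script file name are placeholders:

# Build a tiny data.list with the fields the script reads (placeholder paths).
import json

records = [
    {"key": "utt_0001", "wav": "/path/to/utt_0001.wav", "txt": "reference transcript"},
    {"key": "utt_0002", "wav": "/path/to/utt_0002.wav", "txt": "reference transcript"},
]
with open("data.list", "w") as f:
    for r in records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

# Then run, for example:
# python whisper_distributed_infer.py --data_path data.list \
#     --model_path /data/models/whisper/pytorch_model/large-v3.pt \
#     --result_path result.jsonl --device_num 4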
hotbaby commented Feb 27, 2024

Q: Whisper multi-GPU inference fails with "No support for sparse tensors".

A:
DDP NCCL communication does not support sparse tensors.
The alignment_heads tensor inside the Whisper model is sparse; converting it to a dense tensor resolves the error.
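
The fix can be checked in isolation before running the full script; a minimal sketch, assuming a locally downloaded checkpoint (the path is a placeholder):

import torch
from whisper import Whisper, ModelDimensions

checkpoint = torch.load("/data/models/whisper/pytorch_model/large-v3.pt")  # placeholder path
model = Whisper(ModelDimensions(**checkpoint["dims"]))
model.load_state_dict(checkpoint["model_state_dict"])

buf = model.get_buffer("alignment_heads")
print(buf.is_sparse)  # sparse in the stock model, which is what breaks DDP broadcasts

# Re-register the buffer as a dense tensor so DDP/NCCL can broadcast it.
model.register_buffer("alignment_heads", buf.to_dense(), persistent=False)
print(model.get_buffer("alignment_heads").is_sparse)  # False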
