Last active
February 27, 2024 06:55
-
-
Save hotbaby/fccaf96ac1d0a76cdcfcaf2470326fca to your computer and use it in GitHub Desktop.
whisper多卡分布式推理
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf8 | |
import json | |
import torch | |
import argparse | |
import whisper | |
from whisper import Whisper, ModelDimensions | |
from torch.utils.data import Dataset, DataLoader | |
from lightning import Trainer | |
from lightning import LightningModule | |
from lightning import Callback | |
from typing import Dict | |
from filelock import FileLock | |
# Command-line interface: paths for input data / model / output, plus the
# number of devices to fan prediction out over.
parser = argparse.ArgumentParser()
for _flag in ("--data_path", "--model_path", "--result_path"):
    parser.add_argument(_flag, type=str, required=True)
parser.add_argument("--device_num", type=int, default=1, help="concurrency device number, default is 1")
args = parser.parse_args()
class AishellDataset(Dataset):
    """Map-style dataset over a JSON-lines manifest, one utterance per line.

    Each non-blank line must be a JSON object; downstream ``predict_step``
    expects ``key``, ``wav`` and ``txt`` fields in it.
    """

    def __init__(self, data_path) -> None:
        # Eagerly load the whole manifest into memory.
        self.data_list = []
        # Explicit UTF-8 so transcript text decodes the same on any platform;
        # skip blank lines so a trailing newline does not crash json.loads.
        with open(data_path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    self.data_list.append(json.loads(line))

    def __getitem__(self, idx):
        # Plain list indexing: returns the parsed dict for one utterance.
        return self.data_list[idx]

    def __len__(self):
        return len(self.data_list)
class ResultCallback(Callback):
    """Lightning callback that appends every prediction result to a JSONL file.

    Under multi-device prediction each rank invokes this callback, so an
    inter-process file lock serializes the appends.
    """

    def __init__(self, result_path) -> None:
        super().__init__()
        # Path of the JSONL file accumulating one result dict per line.
        self.result_path = result_path

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Dict,
        batch: Dict,
        batch_idx: int,
        dataloader_idx: int = 0
    ) -> None:
        # File lock keeps concurrent ranks from interleaving their writes.
        with FileLock("/tmp/whisper.lock"):
            # Append mode ("a" — no read access needed); explicit UTF-8 so
            # non-ASCII transcripts round-trip on any platform.
            with open(self.result_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(outputs, ensure_ascii=False) + "\n")
class WhisperWrapModel(LightningModule):
    """LightningModule wrapper around an OpenAI Whisper model for prediction.

    The checkpoint is deserialized onto CPU and Lightning moves the module to
    each rank's device afterwards — without ``map_location`` every DDP rank
    would load the weights onto the device the checkpoint was saved from
    (typically GPU 0).
    """

    def __init__(self, model_path):
        super().__init__()
        # Load on CPU; Lightning relocates the module per rank.
        checkpoint = torch.load(model_path, map_location="cpu")
        dims = ModelDimensions(**checkpoint["dims"])
        model = Whisper(dims)
        model.load_state_dict(checkpoint["model_state_dict"])
        # DDP (NCCL) communication does not support sparse tensors; Whisper
        # registers `alignment_heads` as a sparse buffer, so densify it or
        # DDP fails with "No support for sparse tensors".
        alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
        model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
        self.model: whisper.Whisper = model

    def predict_step(self, batch, batch_idx):
        # batch_size is fixed at 1, so [0] unwraps each collated field.
        wav_path = batch["wav"][0]
        transcribe = self.model.transcribe(wav_path)
        result = {
            "key": batch["key"][0],
            "wav": batch["wav"][0],
            "ref": batch["txt"][0],   # reference transcript from the manifest
            "hyp": transcribe["text"],  # Whisper hypothesis
        }
        return result
if __name__ == "__main__":
    dataset = AishellDataset(data_path=args.data_path)
    # NOTE(review): `dataset[:10]` caps inference at the first 10 samples —
    # this looks like debugging residue; pass `dataset` itself for a full run.
    dataloader = DataLoader(dataset=dataset[:10], batch_size=1)  # batch_size must be 1

    # Truncate any previous results ("w" replaces the a+/truncate(0) pair).
    with open(args.result_path, "w", encoding="utf-8"):
        pass

    lightning_model = WhisperWrapModel(args.model_path)
    trainer = Trainer(
        accelerator="cuda",
        devices=args.device_num,
        callbacks=[ResultCallback(args.result_path)],
    )
    trainer.predict(model=lightning_model, dataloaders=dataloader)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Q: whisper 多卡推理报错,No support for sparse tensors
A: DDP 的 NCCL 通信不支持 sparse tensor。Whisper 模型内部的 alignment_heads tensor 是 sparse tensor,改成 dense tensor 后可以解决这个问题。