Skip to content

Instantly share code, notes, and snippets.

@aleksandr-smechov
Created November 17, 2023 05:01
Show Gist options
  • Save aleksandr-smechov/f9a223ff62521c1644469b3505308d45 to your computer and use it in GitHub Desktop.
Save aleksandr-smechov/f9a223ff62521c1644469b3505308d45 to your computer and use it in GitHub Desktop.
Server-side distil-whisper streaming code
import asyncio
from typing import List
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# Silero voice-activity-detection (VAD) model loaded through torch.hub, using the
# repository's ONNX variant. The second value returned by the hub entry point is
# discarded.
# force_reload=False (torch.hub's default) reuses the locally cached repository
# instead of re-downloading it from GitHub every time the server starts, so
# start-up no longer requires network access once the cache is populated.
vad_model, _ = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",
    model="silero_vad",
    force_reload=False,
    onnx=True,
)
# Pick the execution target once: the first CUDA device in half precision when
# CUDA is available, otherwise the CPU in full precision.
device, torch_dtype = (
    ("cuda:0", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

# Distil-Whisper checkpoint used for speech-to-text.
model_id = "distil-whisper/distil-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=False,
    use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# ASR pipeline assembled from the model and the processor's tokenizer and
# feature extractor. max_new_tokens=64 is set at the pipeline level.
# NOTE(review): the websocket handler also passes
# generate_kwargs={"max_new_tokens": 128} per call; which cap wins depends on
# the installed transformers version — confirm.
transcriber = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=64,
    torch_dtype=torch_dtype,
    device=device,
)

# ASGI application; routes are attached to it via decorators.
app = FastAPI()
class ConnectionManager:
    """Keep track of the WebSocket connections currently being served."""

    def __init__(self):
        # Connections accepted via ``connect`` and not yet passed to ``disconnect``.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    async def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; the socket itself is not closed here.

        Idempotent: calling it for a connection that was already removed (or
        never registered) is a no-op instead of raising ``ValueError``, so
        several error paths can safely call it.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Already untracked; nothing to do.
            pass

    async def send_text(self, message: str, websocket: WebSocket):
        """Send *message* as a text frame to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every tracked connection concurrently.

        The first exception raised by any individual send propagates to the
        caller (``asyncio.gather`` without ``return_exceptions``).
        """
        await asyncio.gather(
            *(connection.send_text(message) for connection in self.active_connections)
        )


# Module-level instance used by the websocket endpoint.
manager = ConnectionManager()
@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """Stream transcriptions of client audio back over a WebSocket.

    The client sends binary frames whose payload is decoded as raw float32
    samples. Each frame the VAD model scores above ``speech_threshold`` is
    transcribed on its own, and the resulting text is sent back as a text
    frame. Frames scored as non-speech, and empty frames, get no reply.

    On any unexpected error the error is logged and the socket is closed with
    code 1001. The connection is always removed from ``manager`` on exit.

    NOTE(review): native byte order and a 16 kHz sample rate are assumed by
    the decoding and the VAD call below — confirm they match the client.
    NOTE(review): ``vad_model`` and ``transcriber`` run synchronously on the
    event loop, so inference blocks every other connection while it runs;
    consider offloading to a worker thread with inference serialized.
    NOTE(review): ``vad_model`` is one module-level instance shared by all
    connections; if it keeps state between calls, concurrent clients will
    affect each other's scores — confirm.
    NOTE(review): RFC 6455 reserves close code 1011 for internal server errors;
    1001 ("going away") is kept here for client compatibility.
    """
    import logging

    sample_rate = 16000  # sampling rate passed to the VAD model with each frame
    speech_threshold = 0.2  # minimum VAD score for a frame to be transcribed

    await manager.connect(websocket)

    async def audio_generator():
        # Yield every received frame that the VAD model scores as speech.
        while True:
            audio_data = await websocket.receive_bytes()
            if not audio_data:
                # A zero-length frame carries no samples; don't hand an empty
                # tensor to the VAD model.
                continue
            np_data = np.frombuffer(audio_data, dtype=np.float32)
            speech_prob = vad_model(torch.tensor(np_data), sample_rate).item()
            if speech_prob > speech_threshold:
                yield np_data

    try:
        async for audio_chunk in audio_generator():
            transcription = transcriber(
                audio_chunk, generate_kwargs={"max_new_tokens": 128}
            )
            await manager.send_text(transcription["text"], websocket)
    except WebSocketDisconnect:
        # The client closed the connection; there is nothing to close.
        pass
    except Exception:
        logging.getLogger(__name__).exception(
            "websocket transcription failed; closing connection"
        )
        try:
            await websocket.close(code=1001)
        except RuntimeError:
            # The socket was already closed; best-effort close only.
            pass
    finally:
        # Runs on every exit path so the connection is never left registered,
        # even if closing the socket above fails.
        await manager.disconnect(websocket)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8000. Binding to
    # "0.0.0.0" listens on every network interface, not just localhost.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment