Self-hosted OpenAI Whisper model served with FastAPI; supports the OpenAI API format and the alumae/ruby-pocketsphinx-server protocol.
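"""Self-hosted Whisper transcription server (whisper-ctranslate2 + FastAPI).

Usage sketch, under a few assumptions: the script is saved as, say,
whisper_server.py (placeholder name), the server runs on the default
host/port below, and speech.mp3 is a placeholder audio file.

    python whisper_server.py --host 0.0.0.0 --port 5000

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",
        files={"file": open("speech.mp3", "rb")},
        data={"prompt": "", "response_type": "json"},
    )
    print(resp.json()["text"])

Note: the output format is selected with a `response_type` form field here;
the official OpenAI API names the equivalent field `response_format`.
"""
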
import wave
import hashlib
import argparse
from datetime import datetime
import os
from typing import Any

from fastapi import File, UploadFile, Form, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import (
    Transcribe,
    TranscriptionOptions,
)
from src.whisper_ctranslate2.writers import format_timestamp
import opencc

# OpenCC converter: Traditional Chinese -> Simplified Chinese
ccc = opencc.OpenCC("t2s.json")

app = FastAPI()

# Allow all CORS origins
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define the allowed file extensions
ALLOWED_EXTENSIONS = {
    "mp3",
    "mp4",
    "mpeg",
    "mpga",
    "m4a",
    "wav",
    "webm",
    "3gp",
    "flac",
    "ogg",
    "mkv",
}


def allowed_file(filename: str | None):
    """Return True if the filename has an allowed audio/video extension."""
    if filename is None:
        return False
    if filename.split(".")[-1] not in ALLOWED_EXTENSIONS:
        return False
    return True


def generate_tsv(result: dict[str, list[Any]]):
    """Render segments as tab-separated start/end times (milliseconds) and text."""
    tsv = "start\tend\ttext\n"
    for segment in result["segments"]:
        start_time = str(round(1000 * segment["start"]))
        end_time = str(round(1000 * segment["end"]))
        text = segment["text"]
        tsv += f"{start_time}\t{end_time}\t{text}\n"
    return tsv


def generate_srt(result: dict[str, list[Any]]):
    """Render segments as SRT cues numbered from 1."""
    srt = ""
    for i, segment in enumerate(result["segments"], start=1):
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = segment["text"]
        srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
    return srt


def generate_vtt(result: dict[str, list[Any]]):
    """Render segments as a WebVTT document."""
    vtt = "WEBVTT\n\n"
    for segment in result["segments"]:
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = segment["text"]
        vtt += f"{start_time} --> {end_time}\n{text}\n\n"
    return vtt
print("Loading model...")
transcriber = Transcribe(
model_path="large-v2",
device="auto",
device_index=0,
compute_type="default",
threads=1,
cache_directory="",
local_files_only=False,
)
print("Model loaded!")


def get_options(*, initial_prompt=""):
    """Build the default TranscriptionOptions, optionally with an initial prompt."""
    options = TranscriptionOptions(
        beam_size=5,
        best_of=5,
        patience=1.0,
        length_penalty=1.0,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        temperature=[0.0, 1.0 + 1e-6, 0.2],
        suppress_tokens=[-1],
        word_timestamps=True,
        print_colors=False,
        prepend_punctuations="\"'“¿([{-",
        append_punctuations="\"'.。,,!!??::”)]}、",
        vad_filter=False,
        vad_threshold=None,
        vad_min_speech_duration_ms=None,
        vad_max_speech_duration_s=None,
        vad_min_silence_duration_ms=None,
        initial_prompt=initial_prompt,
    )
    return options
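

# The /android endpoint below implements the simple HTTP protocol used by
# alumae/ruby-pocketsphinx-server style clients: the request body is raw
# 16-bit PCM audio and the Content-Type header carries the audio parameters.
# An assumed example header, matching the parser below:
#
#     Content-Type: audio/x-raw, rate=16000, channels=1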
@app.post("/android")
async def translateapi(request: Request, task: str = "transcribe"):
content_type = request.headers.get("Content-Type", "")
print("task", task)
print("downloading request file", content_type)
splited = [i.strip() for i in content_type.split(",") if "=" in i]
info = {k: v for k, v in (i.split("=") for i in splited)}
print(info)
channels = int(info.get("channels", "1"))
rate = int(info.get("rate", "16000"))
body = await request.body()
md5 = hashlib.md5(body).hexdigest()
filename = datetime.now().strftime("%Y%m%d-%H%M%S") + "." + md5 + ".wav"
# save the file to a temporary location
file_path = os.path.join("./cache", "android", filename)
with wave.open(file_path, "wb") as buffer:
buffer.setnchannels(channels)
buffer.setsampwidth(2)
buffer.setframerate(rate)
buffer.writeframes(body)
options = get_options()
result = transcriber.inference(
audio=file_path,
task=task,
language="",
verbose=False,
live=False,
options=options,
)
text = result.get("text", "")
text = ccc.convert(text)
print("result", text)
return {
"status": 0,
"hypotheses": [{"utterance": text}],
"id": md5,
}
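
# A minimal client sketch for the /android endpoint above (assumptions: the
# server listens on localhost:5000 and pcm_bytes holds 16 kHz mono 16-bit PCM):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:5000/android?task=transcribe",
#         data=pcm_bytes,
#         headers={"Content-Type": "audio/x-raw, rate=16000, channels=1"},
#     )
#     print(resp.json()["hypotheses"][0]["utterance"])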
@app.post("/v1/audio/transcriptions")
async def transcription(
file: UploadFile = File(...),
prompt: str = Form(""),
response_type: str = Form("json"),
):
"""Transcription endpoint
User upload audio file in multipart/form-data format and receive transcription in response
"""
# check if the file is allowed
if not allowed_file(file.filename):
return {"error": "Invalid file format"}
# timestamp as filename, keep original extension
assert file.filename is not None
filename = (
datetime.now().strftime("%Y%m%d-%H%M%S") + "." + file.filename.split(".")[-1]
)
# save the file to a temporary location
file_path = os.path.join("./cache", filename)
with open(file_path, "wb") as buffer:
buffer.write(file.file.read())
# Define the transcription options
options = get_options(initial_prompt=prompt)
result: Any = transcriber.inference(
audio=file_path,
task="transcribe",
language="",
verbose=False,
live=False,
options=options,
)
if response_type == 'text':
return result["text"].strip()
elif response_type == "json":
return result
elif response_type == "tsv":
return generate_tsv(result)
elif response_type == "srt":
return generate_srt(result)
elif response_type == "vtt":
return generate_vtt(result)
return {"error": "Invalid response_type"}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0", type=str)
    parser.add_argument("--port", default=5000, type=int)
    args = parser.parse_args()

    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)