#!/usr/bin/env python3

import json
import os
import sys
import asyncio
import websockets
from websockets.extensions import permessage_deflate
import concurrent.futures
import logging
import time
from vosk import Model, SpkModel, KaldiRecognizer, SetLogLevel
from urllib.parse import unquote, urlparse
from pathlib import PurePosixPath
import uvloop
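
# The blocking Vosk calls in process_chunk() run inside a ThreadPoolExecutor
# (created in start()) via loop.run_in_executor, so audio decoding does not
# block the asyncio event loop.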
def process_chunk(rec, message):
    # Returns (result_json, stop): stop is True once the client signals end of stream.
    if message == '{"eof" : 1}':
        return rec.FinalResult(), True
    elif rec.AcceptWaveform(message):
        return rec.Result(), False
    else:
        return rec.PartialResult(), False


def is_json(myjson):
    # True if the string parses as JSON.
    try:
        json.loads(myjson)
    except ValueError:
        return False
    return True


async def log_latency(websocket, logger):
    # Helper to measure round-trip ping latency on a connection (not called by default).
    t0 = time.perf_counter()
    pong_waiter = await websocket.ping()
    await pong_waiter
    t1 = time.perf_counter()
    logger.info("Connection latency: %.3f seconds", t1 - t0)
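
# recognize() implements the wire protocol: an optional JSON text message of the form
#   {"config": {"sample_rate": 16000, "words": true, "max_alternatives": 0, "phrase_list": [...]}}
# may arrive first, followed by binary PCM audio chunks; the text message '{"eof" : 1}'
# finalizes the utterance. Results are printed to stdout rather than sent back to the client.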
async def recognize(websocket, path):
    # Handler signature (websocket, path) as used by older websockets releases.
    global model
    global spk_model
    global args
    global pool

    loop = asyncio.get_running_loop()
    rec = None
    phrase_list = None
    sample_rate = args.sample_rate
    show_words = args.show_words
    max_alternatives = args.max_alternatives

    logging.info('Connection from %s', websocket.remote_address)
    logging.info('Request path %s', path)

    # First path component of the request URL, e.g. "client" for ws://host:port/client.
    path_parts = PurePosixPath(unquote(urlparse(path).path)).parts
    extension = path_parts[1] if len(path_parts) > 1 else ''

    previousMessage = ""

    while True:
        start = time.time()
        try:
            message = await websocket.recv()
        except Exception as e:
            logging.info("Websocket receive failed: %s", e)
            # Close the socket; this is most likely a client disconnecting.
            await websocket.close()
            if rec:
                print(rec.FinalResult())
            break

        # Load configuration if provided
        if isinstance(message, str) and 'config' in message:
            jobj = json.loads(message)['config']
            logging.info("Config %s", jobj)
            if 'phrase_list' in jobj:
                phrase_list = jobj['phrase_list']
            if 'sample_rate' in jobj:
                sample_rate = float(jobj['sample_rate'])
            if 'words' in jobj:
                show_words = bool(jobj['words'])
            if 'max_alternatives' in jobj:
                max_alternatives = int(jobj['max_alternatives'])
            continue

        # Create the recognizer on the first audio chunk (not every model supports a phrase list)
        if not rec:
            if phrase_list:
                rec = KaldiRecognizer(model, sample_rate, json.dumps(phrase_list, ensure_ascii=False))
            else:
                rec = KaldiRecognizer(model, sample_rate)
            rec.SetWords(show_words)
            rec.SetMaxAlternatives(max_alternatives)
            if spk_model:
                rec.SetSpkModel(spk_model)

        response, stop = await loop.run_in_executor(pool, process_chunk, rec, message)

        # Print only new, non-empty partial or final results, with simple timing and pool stats.
        if previousMessage == "" or previousMessage != response:
            if isinstance(response, str) and ('partial' in response or 'text' in response) and is_json(response):
                responseJson = json.loads(response)
                if 'partial' in response:
                    if responseJson['partial'].replace(" ", "") != "":
                        print(response)
                        end = time.time()
                        print("Time elapsed:", end - start)
                        print('pending:', pool._work_queue.qsize(), 'jobs')
                        print('threads:', len(pool._threads))
                        print("\n")
                elif 'text' in response:
                    if responseJson['text'].replace(" ", "") != "":
                        print(response)
                        end = time.time()
                        print("Time elapsed:", end - start)
                        print('pending:', pool._work_queue.qsize(), 'jobs')
                        print('threads:', len(pool._threads))
                        print("\n")
        previousMessage = response

        if stop:
            break
async def start():
    global model
    global spk_model
    global args
    global pool

    SetLogLevel(-1)

    logging.basicConfig(level=logging.INFO)
    logging.info('Starting server, one moment')

    args = type('', (), {})()
    args.interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
    args.port = int(os.environ.get('VOSK_SERVER_PORT', 4000))
    args.model_path = os.environ.get('VOSK_MODEL_PATH', 'model')
    args.spk_model_path = os.environ.get('VOSK_SPK_MODEL_PATH')
    args.sample_rate = float(os.environ.get('VOSK_SAMPLE_RATE', 8000))
    args.max_alternatives = int(os.environ.get('VOSK_ALTERNATIVES', 0))
    # bool() on a non-empty string is always True, so parse common "false" spellings explicitly.
    args.show_words = str(os.environ.get('VOSK_SHOW_WORDS', True)).lower() not in ('0', 'false', 'no')

    if len(sys.argv) > 1:
        args.model_path = sys.argv[1]

    model = Model(args.model_path)
    spk_model = SpkModel(args.spk_model_path) if args.spk_model_path else None

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=(os.cpu_count() or 8))

    async with websockets.serve(recognize, args.interface, args.port,
                                ping_interval=None, ping_timeout=None, compression=None):
        await asyncio.Future()
    # To enable per-message deflate compression instead of compression=None, pass e.g.:
    # extensions=[
    #     permessage_deflate.ServerPerMessageDeflateFactory(
    #         server_max_window_bits=11,
    #         client_max_window_bits=11,
    #         compress_settings={"memLevel": 8})
    # ]


if __name__ == '__main__':
    if sys.version_info >= (3, 11):
        # Python 3.11+: run on a uvloop event loop via asyncio.Runner.
        with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
            runner.run(start())
    else:
        uvloop.install()
        asyncio.run(start())
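
For reference, a minimal test client might look like the sketch below. It is an illustration only: the file name test.wav, the /client path, and the chunk size are assumptions, and the audio is expected to be raw mono 16-bit PCM at the sample rate declared in the config message. Note that this server prints recognition results to its own stdout instead of sending them back over the socket, so the client only streams audio.

import asyncio
import json
import wave

import websockets


async def stream_file(path="test.wav", uri="ws://localhost:4000/client"):
    # Hypothetical client: path, URI, and chunk size are illustrative only.
    with wave.open(path, "rb") as wf:
        async with websockets.connect(uri) as ws:
            # Optional configuration message, matching the {"config": ...} handler above.
            await ws.send(json.dumps({"config": {"sample_rate": wf.getframerate(), "words": True}}))
            while True:
                chunk = wf.readframes(4000)  # 4000 frames per send is an arbitrary choice
                if not chunk:
                    break
                await ws.send(chunk)  # binary PCM audio
            await ws.send('{"eof" : 1}')  # ask the server to finalize the utterance


asyncio.run(stream_file())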