Skip to content

Instantly share code, notes, and snippets.

@Goddard
Last active October 4, 2022 16:12
Show Gist options
  • Save Goddard/b86c0469c42e1f4c415f37354a5f30db to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import asyncio
import concurrent.futures
import json
import logging
import os
import pathlib
import sys
import time
import types
from pathlib import PurePosixPath
from urllib.parse import unquote, urlparse

import requests
import uvloop
import websockets
from websockets.extensions import permessage_deflate

from vosk import Model, SpkModel, KaldiRecognizer, SetLogLevel
def process_chunk(rec, message):
    """Feed one websocket message into the recognizer.

    Returns a ``(result_json, stop)`` pair.  ``stop`` is True only for the
    EOF sentinel message, in which case the final result is flushed.
    """
    if message == '{"eof" : 1}':
        return rec.FinalResult(), True
    if rec.AcceptWaveform(message):
        return rec.Result(), False
    return rec.PartialResult(), False
def is_json(myjson):
    """Return True when *myjson* parses as JSON, False otherwise."""
    try:
        json.loads(myjson)
        return True
    except ValueError:
        return False
async def log_latency(websocket, logger):
    """Ping the peer and log how long the round trip took."""
    started = time.perf_counter()
    pong = await websocket.ping()
    await pong
    elapsed = time.perf_counter() - started
    logger.info("Connection latency: %.3f seconds", elapsed)
def _print_result(response, start):
    """Debug-print one recognizer result plus timing and thread-pool stats."""
    print(response)
    print("Time elapsed:", time.time() - start)
    print('pending:', pool._work_queue.qsize(), 'jobs')
    print('threads:', len(pool._threads))
    print("\n")


async def recognize(websocket, path):
    """Serve one websocket client.

    Receives either a JSON config message (``{"config": {...}}``) or raw
    audio chunks, decodes the audio in the shared thread pool via
    ``process_chunk``, and prints partial/final results.  Relies on the
    module globals set up by ``start()``: model, spk_model, args, pool.
    """
    global model
    global spk_model
    global args
    global pool

    loop = asyncio.get_running_loop()
    rec = None
    phrase_list = None
    sample_rate = args.sample_rate
    show_words = args.show_words
    max_alternatives = args.max_alternatives

    logging.info('Connection from %s', websocket.remote_address)
    logging.info('Request Path %s', path)
    # First path component of the request URL; currently informational only.
    extension = PurePosixPath(unquote(urlparse(path).path)).parts[1]

    previous_message = ""
    while True:
        start = time.time()
        try:
            message = await websocket.recv()
        except Exception as e:
            logging.info("Websocket Receive Failed : %s", e)
            # Likely a client disconnect: close our side and flush any
            # pending final result before giving up.
            await websocket.close()
            # BUG FIX: rec is still None when the client disconnects before
            # sending any audio; the old code crashed on rec.FinalResult().
            if rec is not None:
                print(rec.FinalResult())
            break

        # Load configuration if provided.
        if isinstance(message, str) and 'config' in message:
            jobj = json.loads(message)['config']
            logging.info("Config %s", jobj)
            if 'phrase_list' in jobj:
                phrase_list = jobj['phrase_list']
            if 'sample_rate' in jobj:
                sample_rate = float(jobj['sample_rate'])
            if 'words' in jobj:
                show_words = bool(jobj['words'])
            if 'max_alternatives' in jobj:
                max_alternatives = int(jobj['max_alternatives'])
            continue

        # Create the recognizer lazily; the phrase list is only passed when
        # supplied, since not every model supports it.
        if not rec:
            if phrase_list:
                rec = KaldiRecognizer(model, sample_rate,
                                      json.dumps(phrase_list, ensure_ascii=False))
            else:
                rec = KaldiRecognizer(model, sample_rate)
            rec.SetWords(show_words)
            rec.SetMaxAlternatives(max_alternatives)
            if spk_model:
                rec.SetSpkModel(spk_model)

        # Decode in the thread pool so the event loop stays responsive.
        response, stop = await loop.run_in_executor(pool, process_chunk, rec, message)

        # Only report results that changed since the last message.
        if previous_message == "" or previous_message != response:
            if (isinstance(response, str) and is_json(response)
                    and ('partial' in response or 'text' in response)):
                response_json = json.loads(response)
                if 'partial' in response:
                    # Skip empty / whitespace-only partial hypotheses.
                    if response_json['partial'].replace(" ", "") != "":
                        _print_result(response, start)
                elif 'text' in response:
                    if response_json['text'].replace(" ", "") != "":
                        _print_result(response, start)
            previous_message = response

        if stop:
            break
async def start():
    """Configure logging, load the Vosk models from environment settings,
    and run the websocket server until cancelled.

    Populates the module globals consumed by ``recognize()``: model,
    spk_model, args, pool.
    """
    global model
    global spk_model
    global args
    global pool

    SetLogLevel(-1)
    logging.basicConfig(level=logging.INFO)
    logging.info('Starting server, one moment')

    # Configuration comes from the environment with sensible defaults.
    args = types.SimpleNamespace()
    args.interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
    args.port = int(os.environ.get('VOSK_SERVER_PORT', 4000))
    args.model_path = os.environ.get('VOSK_MODEL_PATH', 'model')
    args.spk_model_path = os.environ.get('VOSK_SPK_MODEL_PATH')
    args.sample_rate = float(os.environ.get('VOSK_SAMPLE_RATE', 8000))
    args.max_alternatives = int(os.environ.get('VOSK_ALTERNATIVES', 0))
    # BUG FIX: bool(env_string) is True for ANY non-empty string, so
    # VOSK_SHOW_WORDS=0 or =false used to enable words anyway.  Parse the
    # common falsy spellings explicitly; the default remains True.
    args.show_words = (os.environ.get('VOSK_SHOW_WORDS', 'true')
                       .strip().lower() not in ('', '0', 'false', 'no', 'off'))

    # A command-line argument overrides the model path from the environment.
    if len(sys.argv) > 1:
        args.model_path = sys.argv[1]

    model = Model(args.model_path)
    spk_model = SpkModel(args.spk_model_path) if args.spk_model_path else None

    # One worker per CPU; fall back to 8 when the count is undeterminable.
    pool = concurrent.futures.ThreadPoolExecutor(os.cpu_count() or 8)

    # Heartbeats are disabled; clients signal end-of-stream explicitly.
    async with websockets.serve(recognize, args.interface, args.port,
                                ping_interval=None, ping_timeout=None,
                                compression=None):
        await asyncio.Future()  # run forever
    # Optional permessage-deflate tuning, kept for reference:
    # extensions=[
    #     permessage_deflate.ServerPerMessageDeflateFactory(
    #         server_max_window_bits=11,
    #         client_max_window_bits=11,
    #         compress_settings={"memLevel": 8})
    # ]
if __name__ == '__main__':
    # Run under uvloop for a faster event loop.  On Python < 3.11 only the
    # install()/asyncio.run() pattern exists; 3.11+ prefers asyncio.Runner
    # with an explicit loop factory.
    if sys.version_info < (3, 11):
        uvloop.install()
        asyncio.run(start())
    else:
        with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
            runner.run(start())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment