Last active
May 26, 2022 14:59
-
-
Save JustinaPetr/82b81d859836fb3dc08217edff0eb6bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import uuid | |
from sanic import Blueprint, response | |
from sanic.request import Request | |
from socketio import AsyncServer | |
from typing import Optional, Text, Any, List, Dict, Iterable | |
from rasa.core.channels.channel import InputChannel | |
from rasa.core.channels.channel import UserMessage, OutputChannel | |
import deepspeech | |
from deepspeech import Model | |
import scipy.io.wavfile as wav | |
import os | |
import sys | |
import io | |
import torch | |
import time | |
import numpy as np | |
from collections import OrderedDict | |
import urllib | |
import librosa | |
from TTS.models.tacotron import Tacotron | |
from TTS.layers import * | |
from TTS.utils.data import * | |
from TTS.utils.audio import AudioProcessor | |
from TTS.utils.generic_utils import load_config | |
from TTS.utils.text import text_to_sequence | |
from TTS.utils.synthesis import synthesis | |
from utils.text.symbols import symbols, phonemes | |
from TTS.utils.visual import visualize | |
logger = logging.getLogger(__name__) | |
def load_deepspeech_model(model_dir='deepspeech-0.5.1-models'):
    """Load the DeepSpeech speech-to-text model.

    Args:
        model_dir: Directory holding ``output_graph.pbmm`` and
            ``alphabet.txt``. The default reproduces the paths that were
            previously hard-coded, so zero-argument callers are unaffected.

    Returns:
        The constructed ``Model`` instance.
    """
    # Positional constructor settings for the acoustic model / decoder.
    n_features = 25
    n_context = 9
    beam_width = 500
    # NOTE(review): the original also declared LM_ALPHA = 0.75 and
    # LM_BETA = 1.85 but never used them; no language model is enabled
    # here. Confirm whether that is intentional.
    return Model(
        '{}/output_graph.pbmm'.format(model_dir),
        n_features,
        n_context,
        '{}/alphabet.txt'.format(model_dir),
        beam_width,
    )
def load_tts_model():
    """Build the Tacotron text-to-speech model and its audio processor.

    The configuration is read from ``./tts_model/config.json`` and the
    weights from ``./tts_model/best_model.pth.tar``. ``use_cuda`` is fixed
    to ``False``, so the checkpoint is mapped onto CPU storage.

    Returns:
        A tuple ``(model, ap, MODEL_PATH, CONFIG, use_cuda)``.
    """
    MODEL_PATH = './tts_model/best_model.pth.tar'
    CONFIG_PATH = './tts_model/config.json'
    CONFIG = load_config(CONFIG_PATH)
    use_cuda = False

    # Build the network once (the original constructed it twice and threw
    # the first instance away).
    num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)
    model = Tacotron(
        num_chars,
        CONFIG.embedding_size,
        CONFIG.audio['num_freq'],
        CONFIG.audio['num_mels'],
        CONFIG.r,
        attn_windowing=False,
    )

    # Load the audio processor, overriding the configured pre-emphasis.
    # (An override of CONFIG.audio["power"] to 1.3 was left disabled.)
    CONFIG.audio["preemphasis"] = 0.97
    ap = AudioProcessor(**CONFIG.audio)

    # Load the checkpoint; without CUDA every tensor stays on its
    # deserialised (CPU) storage.
    if use_cuda:
        cp = torch.load(MODEL_PATH)
    else:
        cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)

    model.load_state_dict(cp['model'])
    if use_cuda:
        model.cuda()
    # NOTE(review): model.eval() was commented out in the original and is
    # deliberately left disabled here - confirm whether inference is meant
    # to run in training mode.
    model.decoder.max_decoder_steps = 1000
    return model, ap, MODEL_PATH, CONFIG, use_cuda
# Both speech models are loaded once, at import time, and kept as module
# globals: `ds` for speech-to-text; `model`/`ap` (plus the path, config and
# CUDA flag) for text-to-speech.
ds = load_deepspeech_model()
(model, ap, MODEL_PATH, CONFIG, use_cuda) = load_tts_model()
class SocketBlueprint(Blueprint):
    """Blueprint that attaches a socket.io server to the app on registration."""

    def __init__(self, sio: AsyncServer, socketio_path, *args, **kwargs):
        """Remember the socket.io server and its mount path.

        Remaining positional and keyword arguments are forwarded to the
        parent blueprint constructor.
        """
        self.sio = sio
        self.socketio_path = socketio_path
        super().__init__(*args, **kwargs)

    def register(self, app, options):
        """Attach the socket.io server at ``socketio_path``, then register as usual."""
        self.sio.attach(app, self.socketio_path)
        super().register(app, options)
class SocketIOOutput(OutputChannel):
    """Output channel that replies over socket.io with text and a speech link.

    Every text reply is synthesised with the module-level TTS model, saved
    as a ``.wav`` file, and announced to the client together with a URL
    built from ``audio_base_url``.
    """

    # Prefix of the URL sent to the client for each generated audio file.
    # NOTE(review): nothing in this file serves those files - presumably a
    # static file server runs on port 8888 from the working directory;
    # confirm against the deployment.
    audio_base_url = "http://localhost:8888/"

    @classmethod
    def name(cls):
        """Return the channel name."""
        return "socketio"

    def __init__(self, sio, sid, bot_message_evt, message):
        """Store the socket.io server, the client's session id, the event
        name used for bot replies and the incoming message payload."""
        self.sio = sio
        self.sid = sid
        self.bot_message_evt = bot_message_evt
        self.message = message

    def tts(self, model, text, CONFIG, use_cuda, ap, OUT_FILE):
        """Synthesise ``text`` and save the waveform to ``OUT_FILE``.

        Returns:
            ``(alignment, spectrogram, stop_tokens, wav_norm)`` where
            ``wav_norm`` is the waveform rescaled so that its peak
            magnitude maps to 32767.
        """
        waveform, alignment, spectrogram, _mel_spectrogram, stop_tokens = synthesis(
            model, text, CONFIG, use_cuda, ap
        )
        ap.save_wav(waveform, OUT_FILE)
        # The 0.01 floor keeps the scale factor finite for (near-)silent output.
        peak = max(0.01, np.max(np.abs(waveform)))
        wav_norm = waveform * (32767 / peak)
        return alignment, spectrogram, stop_tokens, wav_norm

    def tts_predict(self, MODEL_PATH, sentence, CONFIG, use_cuda, OUT_FILE):
        """Synthesise ``sentence`` with the module-level ``model`` and ``ap``.

        ``MODEL_PATH`` is accepted for interface compatibility but is not
        used. Returns the normalised waveform produced by :meth:`tts`.
        """
        _align, _spec, _stop_tokens, wav_norm = self.tts(
            model, sentence, CONFIG, use_cuda, ap, OUT_FILE
        )
        return wav_norm

    async def _send_audio_message(self, socket_id, response, **kwargs: Any) -> None:
        """Synthesise ``response['text']`` and emit it with the audio link.

        The audio file is named after the current timestamp and written
        relative to the current working directory.
        """
        ts = time.time()
        out_file = str(ts) + '.wav'
        link = self.audio_base_url + out_file
        # NOTE(review): synthesis is a synchronous call made from inside
        # this coroutine.
        self.tts_predict(MODEL_PATH, response['text'], CONFIG, use_cuda, out_file)
        await self.sio.emit(
            self.bot_message_evt,
            {'text': response['text'], "link": link},
            room=socket_id,
        )

    async def send_text_message(self, recipient_id: Text, message: Text, **kwargs: Any) -> None:
        """Send a message through this channel.

        The reply always goes to the session this channel was created for
        (``self.sid``); ``recipient_id`` is not used.
        """
        await self._send_audio_message(self.sid, {"text": message})
class SocketIOInput(InputChannel):
    """A socket.io input channel that accepts recorded speech.

    Incoming messages are either the literal ``/get_started`` command or a
    URL of a WAV recording, which is downloaded and transcribed with the
    module-level DeepSpeech model ``ds``.
    """

    @classmethod
    def name(cls):
        """Return the channel name."""
        return "socketio"

    @classmethod
    def from_credentials(cls, credentials):
        """Build the channel from a credentials mapping (``None`` allowed).

        Missing keys fall back to the same defaults as ``__init__``.
        """
        credentials = credentials or {}
        return cls(
            credentials.get("user_message_evt", "user_uttered"),
            credentials.get("bot_message_evt", "bot_uttered"),
            credentials.get("namespace"),
            credentials.get("session_persistence", False),
            credentials.get("socketio_path", "/socket.io"),
        )

    def __init__(self,
                 user_message_evt: Text = "user_uttered",
                 bot_message_evt: Text = "bot_uttered",
                 namespace: Optional[Text] = None,
                 session_persistence: bool = False,
                 socketio_path: Optional[Text] = '/socket.io'
                 ):
        """Store the event names, namespace and mount path for the channel.

        ``session_persistence`` is stored but not read anywhere in this class.
        """
        self.bot_message_evt = bot_message_evt
        self.session_persistence = session_persistence
        self.user_message_evt = user_message_evt
        self.namespace = namespace
        self.socketio_path = socketio_path

    def blueprint(self, on_new_message):
        """Create the blueprint and register the socket.io event handlers.

        Args:
            on_new_message: Coroutine function awaited with a
                ``UserMessage`` for every message received.

        Returns:
            The ``SocketBlueprint`` carrying the health route and the
            attached socket.io server.
        """
        # A bare `import urllib` does not guarantee that the `request`
        # submodule is loaded, so import it explicitly before it is used.
        import urllib.request

        sio = AsyncServer(async_mode="sanic")
        socketio_webhook = SocketBlueprint(
            sio, self.socketio_path, "socketio_webhook", __name__
        )

        @socketio_webhook.route("/", methods=['GET'])
        async def health(request):
            return response.json({"status": "ok"})

        @sio.on('connect', namespace=self.namespace)
        async def connect(sid, environ):
            logger.debug("User {} connected to socketIO endpoint.".format(sid))
            print('Connected!')

        @sio.on('disconnect', namespace=self.namespace)
        async def disconnect(sid):
            logger.debug("User {} disconnected from socketIO endpoint."
                         "".format(sid))

        @sio.on('session_request', namespace=self.namespace)
        async def session_request(sid, data):
            print('This is sessioin request')
            if data is None:
                data = {}
            # Issue a fresh session id when the client did not supply one.
            if 'session_id' not in data or data['session_id'] is None:
                data['session_id'] = uuid.uuid4().hex
            await sio.emit("session_confirm", data['session_id'], room=sid)
            logger.debug("User {} connected to socketIO endpoint."
                         "".format(sid))

        # Listen on the configured event name; the original hard-coded
        # 'user_uttered' here, which silently ignored `user_message_evt`.
        @sio.on(self.user_message_evt, namespace=self.namespace)
        async def handle_message(sid, data):
            if not data or 'message' not in data:
                logger.warning("Ignoring message without a 'message' field "
                               "from user {}.".format(sid))
                return

            output_channel = SocketIOOutput(sio, sid, self.bot_message_evt, data['message'])
            if data['message'] == "/get_started":
                message = data['message']
            else:
                # Receive audio: download the recording, then transcribe it.
                # NOTE(security): the URL comes straight from the client and
                # is fetched without validation; urlretrieve accepts every
                # scheme urllib supports. Restrict the allowed schemes/hosts
                # once the client's URL format is confirmed.
                # NOTE(review): download and transcription are synchronous
                # calls made from inside this coroutine.
                received_file = 'output_' + sid + '.wav'
                urllib.request.urlretrieve(data['message'], received_file)
                try:
                    fs, audio = wav.read(received_file)
                finally:
                    # The recording is only needed for this read; do not
                    # leave one file per session behind on disk.
                    os.remove(received_file)
                message = ds.stt(audio, fs)
                # Echo the transcript back to the client.
                await sio.emit(self.user_message_evt, {"text": message}, room=sid)

            message_rasa = UserMessage(message, output_channel, sid,
                                       input_channel=self.name())
            await on_new_message(message_rasa)

        return socketio_webhook
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment