-
-
Save phixion/8719bfe5a85b8109b5b2330cc01690a1 to your computer and use it in GitHub Desktop.
Replace Home Assistant's TTS google.py with code that uses Google Cloud WaveNet voices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Support for the Google speech service.

For more details about this component, please refer to the documentation at
https://home-assistant.io/components/tts.google/

Note: this is a hack. It makes no attempt to update tests, and it does not list
all WaveNet voices. It attempts to respect language requests from hass, but
this has not been tested.

Google Cloud provides the first 1 million characters of WaveNet generation for
free each month; if you exceed that number, usage is metered and charged.

For this to work, you must have downloaded a provisioning key from Google as
detailed in this quickstart:
https://cloud.google.com/text-to-speech/docs/quickstart-client-libraries
Follow the instructions there, and make sure that the
GOOGLE_APPLICATION_CREDENTIALS environment variable is set properly before
booting hass.
""" | |
import asyncio | |
import logging | |
import re | |
import async_timeout | |
import voluptuous as vol | |
import yarl | |
import os | |
from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider | |
from homeassistant.helpers.aiohttp_client import async_get_clientsession | |
from google.cloud import texttospeech | |
# NOTE(review): this module imports google.cloud.texttospeech, but only
# gTTS-token is listed here -- confirm which package should be declared.
REQUIREMENTS = ['gTTS-token==1.1.3']

_LOGGER = logging.getLogger(__name__)

# Fallback location of the Google service-account key. ``setdefault`` is used
# so that a GOOGLE_APPLICATION_CREDENTIALS value already exported before
# start-up is respected instead of being overwritten by this placeholder.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "your-key-path.json")

# Not referenced by the code visible in this module.
GOOGLE_SPEECH_URL = "https://translate.google.com/translate_tts"

# Maximum length of one chunk produced by
# GoogleProvider._split_message_to_parts.
MESSAGE_SIZE = 148

# Language keys accepted for the platform's ``language`` option.
SUPPORT_LANGUAGES = [
    'af', 'sq', 'ar', 'hy', 'bn', 'ca', 'zh', 'zh-cn', 'zh-tw', 'zh-yue',
    'hr', 'cs', 'da', 'nl', 'en', 'en-au', 'en-uk', 'en-us', 'eo', 'fi',
    'fr', 'de', 'el', 'hi', 'hu', 'is', 'id', 'it', 'ja', 'ko', 'la', 'lv',
    'mk', 'no', 'pl', 'pt', 'pt-br', 'ro', 'ru', 'sr', 'sk', 'es', 'es-es',
    'es-mx', 'es-us', 'sw', 'sv', 'ta', 'th', 'tr', 'vi', 'cy', 'uk', 'bg-BG'
]

DEFAULT_LANG = 'de'

# Voice identifiers looked up by GoogleProvider, keyed by language key.
# Languages missing from this table are handled by the provider's fallback.
WAVENET_LOOKUP = {
    'en': 'en-US-Wavenet-C',
    'en-au': 'en-AU-Wavenet-A',
    'en-uk': 'en-GB-Wavenet-A',
    'fr': 'fr-FR-Wavenet-A',
    'de': 'de-DE-Wavenet-A',
    'it': 'it-IT-Wavenet-A',
    'sv': 'sv-SE-Wavenet-A',
}

# Extend the base TTS platform schema with the optional language setting,
# restricted to SUPPORT_LANGUAGES and defaulting to DEFAULT_LANG.
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
    vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES),
})
async def async_get_engine(hass, config):
    """Create the Google speech provider for this platform.

    Args:
        hass: Object handed to GoogleProvider as its ``hass`` argument.
        config: Mapping that contains the configured language under CONF_LANG.

    Returns:
        A GoogleProvider built for the configured language.
    """
    configured_language = config[CONF_LANG]
    provider = GoogleProvider(hass, configured_language)
    return provider
class GoogleProvider(Provider):
    """The Google speech API provider.

    Synthesizes MP3 audio through a ``texttospeech.TextToSpeechClient`` using
    the voices listed in WAVENET_LOOKUP.
    """

    # Seconds to wait for a single synthesis request before giving up.
    _REQUEST_TIMEOUT = 10

    def __init__(self, hass, lang):
        """Init Google TTS service.

        Args:
            hass: Home Assistant instance; its event loop schedules requests.
            lang: Default language key for this provider.
        """
        self.hass = hass
        self._lang = lang
        self.name = 'Google'
        self.client = texttospeech.TextToSpeechClient()
        # Languages without an entry in WAVENET_LOOKUP use the German voice.
        default_voice = WAVENET_LOOKUP.get(self._lang, WAVENET_LOOKUP['de'])
        self.voice = self._build_voice(default_voice)
        self.audio_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.MP3)

    @staticmethod
    def _build_voice(voice_name):
        """Return voice selection parameters for a voice name.

        The voice name (e.g. ``en-US-Wavenet-C``) is passed as ``name`` and
        its first two dash-separated components (``en-US``) are passed as
        ``language_code``, rather than passing the whole voice name as the
        language code.
        """
        language_code = '-'.join(voice_name.split('-')[:2])
        return texttospeech.types.VoiceSelectionParams(
            language_code=language_code, name=voice_name)

    @property
    def default_language(self):
        """Return the default language."""
        return self._lang

    @property
    def supported_languages(self):
        """Return list of supported languages."""
        return SUPPORT_LANGUAGES

    async def async_get_tts_audio(self, message, language, options=None):
        """Load TTS from google.

        Args:
            message: Text to synthesize.
            language: Requested language key. When it has an entry in
                WAVENET_LOOKUP that voice is used; otherwise the provider's
                default voice is used.
            options: Unused.

        Returns:
            ``("mp3", audio_bytes)`` on success, ``(None, None)`` on timeout
            or any other failure (the failure is logged).
        """
        synthesis_input = texttospeech.types.SynthesisInput(text=message)

        requested_voice = WAVENET_LOOKUP.get(language)
        voice = (self._build_voice(requested_voice)
                 if requested_voice else self.voice)

        try:
            # synthesize_speech is a blocking call: run it in the default
            # executor so the event loop stays responsive and the timeout
            # can actually fire.
            response = await asyncio.wait_for(
                self.hass.loop.run_in_executor(
                    None,
                    self.client.synthesize_speech,
                    synthesis_input,
                    voice,
                    self.audio_config),
                timeout=self._REQUEST_TIMEOUT)
        except asyncio.TimeoutError:
            _LOGGER.error("Timeout for google speech.")
            return (None, None)
        except Exception as err:  # best-effort: report and return no audio
            _LOGGER.error("Error during google speech synthesis: %s", err)
            return (None, None)

        return ("mp3", response.audio_content)

    @staticmethod
    def _split_message_to_parts(message):
        """Split message into single parts.

        Messages no longer than MESSAGE_SIZE are returned unchanged as a
        one-element list. Longer messages are split on punctuation (which is
        dropped) and each piece is then broken at spaces so that no part
        exceeds MESSAGE_SIZE. Empty parts are discarded.
        """
        if len(message) <= MESSAGE_SIZE:
            return [message]

        pattern = '|'.join(re.escape(char) for char in "!()[]?.,;:")
        parts = re.split(pattern, message)

        def split_by_space(fullstring):
            """Break a string at spaces into chunks of at most MESSAGE_SIZE."""
            chunks = []
            while len(fullstring) > MESSAGE_SIZE:
                idx = fullstring.rfind(' ', 0, MESSAGE_SIZE)
                if idx <= 0:
                    # No usable space (none found, or only a leading one):
                    # hard-cut so each pass consumes input and no chunk is
                    # longer than MESSAGE_SIZE.
                    idx = MESSAGE_SIZE
                chunks.append(fullstring[:idx])
                fullstring = fullstring[idx:]
            chunks.append(fullstring)
            return chunks

        msg_parts = []
        for part in parts:
            msg_parts.extend(split_by_space(part))

        return [msg for msg in msg_parts if msg]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment