Last active
February 11, 2024 22:06
-
-
Save seang96/cd6b7d9184f654d67fe4d9febcd5b81e to your computer and use it in GitHub Desktop.
Home Assistant ESPHome Voice Assistant Port Fix Update (Not Working)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""ESPHome voice assistant support.""" | |
from __future__ import annotations | |
import asyncio | |
from collections.abc import AsyncIterable, Callable | |
import logging | |
import socket | |
from typing import cast | |
from aioesphomeapi import ( | |
VoiceAssistantAudioSettings, | |
VoiceAssistantCommandFlag, | |
VoiceAssistantEventType, | |
) | |
from homeassistant.components import stt, tts | |
from homeassistant.components.assist_pipeline import ( | |
AudioSettings, | |
PipelineEvent, | |
PipelineEventType, | |
PipelineNotFound, | |
PipelineStage, | |
WakeWordSettings, | |
async_pipeline_from_audio_stream, | |
select as pipeline_select, | |
) | |
from homeassistant.components.assist_pipeline.error import ( | |
WakeWordDetectionAborted, | |
WakeWordDetectionError, | |
) | |
from homeassistant.components.media_player import async_process_play_media_url | |
from homeassistant.core import Context, HomeAssistant, callback | |
from .const import DOMAIN | |
from .entry_data import RuntimeEntryData | |
from .enum_mapper import EsphomeEnumMapper | |
_LOGGER = logging.getLogger(__name__)
# Base UDP port; start_server() probes UDP_PORT..UDP_PORT+7 for a free one.
UDP_PORT = 9124  # Set to 0 to let the OS pick a free random port
# Presumably the max payload size for a single UDP audio packet — note the
# visible code hard-codes 1024 in _send_tts rather than referencing this.
UDP_MAX_PACKET_SIZE = 1024
# Translation table between ESPHome voice-assistant event types and Home
# Assistant assist_pipeline event types; used via .from_hass() to convert
# pipeline events before forwarding them to the device.
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
    VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
    {
        VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
    }
)
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
    """Receive UDP audio packets and forward them to the voice assistant.

    The ESPHome device streams microphone audio to this server over UDP;
    the packets are fed into a Home Assistant assist pipeline, and pipeline
    events (plus, for voice_assistant_version >= 2 devices, raw TTS audio)
    are sent back to the device.
    """

    started = False
    stopped = False
    transport: asyncio.DatagramTransport | None = None
    remote_addr: tuple[str, int] | None = None

    def __init__(
        self,
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
    ) -> None:
        """Initialize UDP receiver.

        handle_event forwards translated pipeline events to the device;
        handle_finished is invoked once the pipeline run is over or aborted.
        """
        self.context = Context()
        self.hass = hass
        assert entry_data.device_info is not None
        self.entry_data = entry_data
        self.device_info = entry_data.device_info
        # Incoming audio packets; an empty bytes object is the stop sentinel.
        self.queue: asyncio.Queue[bytes] = asyncio.Queue()
        self.handle_event = handle_event
        self.handle_finished = handle_finished
        # Set once TTS audio has been fully sent (or on error), so
        # run_pipeline() can block until the response finished streaming.
        self._tts_done = asyncio.Event()

    async def start_server(self) -> int:
        """Start accepting connections and return the bound UDP port.

        Tries UDP_PORT first, then scans the following 7 ports if taken.

        Raises:
            RuntimeError: if no port in the range could be bound, or the
                server was already started/stopped.
        """

        def accept_connection() -> VoiceAssistantUDPServer:
            """Accept connection."""
            if self.started:
                raise RuntimeError("Can only start once")
            if self.stopped:
                raise RuntimeError("No longer accepting connections")
            self.started = True
            return self

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setblocking(False)
        # SO_REUSEADDR must be set *before* bind() to have any effect
        # (the original code set it after a successful bind).
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        for counter in range(8):
            try:
                sock.bind(("", UDP_PORT + counter))
                break
            except OSError:
                _LOGGER.debug("Tried port %i. Already in use.", UDP_PORT + counter)
        else:
            # Loop exhausted without a successful bind. (The original
            # `if counter > 8` check could never fire: range(8) tops out
            # at 7, so an unbound socket was silently used instead.)
            raise RuntimeError("cannot find a free port within range")
        await asyncio.get_running_loop().create_datagram_endpoint(
            accept_connection, sock=sock
        )
        return cast(int, sock.getsockname()[1])

    @callback
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store transport for later use."""
        self.transport = cast(asyncio.DatagramTransport, transport)

    @callback
    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Handle incoming UDP packet: remember the sender and queue audio."""
        if not self.started or self.stopped:
            return
        if self.remote_addr is None:
            # First packet wins: replies (TTS audio) go back to this address.
            self.remote_addr = addr
        self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:
        """Handle when a send or receive operation raises an OSError.

        (Other than BlockingIOError or InterruptedError.)
        """
        _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
        self.handle_finished()

    @callback
    def stop(self) -> None:
        """Stop the receiver.

        Pushes the empty-bytes sentinel so _iterate_packets() terminates.
        """
        self.queue.put_nowait(b"")
        self.started = False
        self.stopped = True

    def close(self) -> None:
        """Close the receiver and its transport."""
        self.started = False
        self.stopped = True
        if self.transport is not None:
            self.transport.close()

    async def _iterate_packets(self) -> AsyncIterable[bytes]:
        """Iterate over incoming packets until the empty sentinel arrives.

        Raises:
            RuntimeError: if the server is not currently running.
        """
        if not self.started or self.stopped:
            raise RuntimeError("Not running")
        while data := await self.queue.get():
            yield data

    def _event_callback(self, event: PipelineEvent) -> None:
        """Translate a pipeline event and forward it to the device."""
        try:
            event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
        except KeyError:
            _LOGGER.debug("Received unknown pipeline event type: %s", event.type)
            return
        data_to_send = None
        error = False
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
            self.entry_data.async_set_assist_pipeline_state(True)
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
            assert event.data is not None
            data_to_send = {"text": event.data["stt_output"]["text"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
            assert event.data is not None
            data_to_send = {
                "conversation_id": event.data["intent_output"]["conversation_id"] or "",
            }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
            assert event.data is not None
            data_to_send = {"text": event.data["tts_input"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
            assert event.data is not None
            path = event.data["tts_output"]["url"]
            url = async_process_play_media_url(self.hass, path)
            data_to_send = {"url": url}
            if self.device_info.voice_assistant_version >= 2:
                # Newer devices receive the raw TTS audio over UDP; stream
                # it in the background and signal _tts_done when finished.
                media_id = event.data["tts_output"]["media_id"]
                self.hass.async_create_background_task(
                    self._send_tts(media_id), "esphome_voice_assistant_tts"
                )
            else:
                # Older devices fetch the URL themselves; nothing to stream.
                self._tts_done.set()
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
            assert event.data is not None
            if not event.data["wake_word_output"]:
                # Timed out without hearing a wake word; report as an error.
                event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
                data_to_send = {
                    "code": "no_wake_word",
                    "message": "No wake word detected",
                }
                error = True
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            assert event.data is not None
            data_to_send = {
                "code": event.data["code"],
                "message": event.data["message"],
            }
            error = True
        self.handle_event(event_type, data_to_send)
        if error:
            # Unblock run_pipeline() and tear down on any error event.
            self._tts_done.set()
            self.handle_finished()

    async def run_pipeline(
        self,
        device_id: str,
        conversation_id: str | None,
        flags: int = 0,
        audio_settings: VoiceAssistantAudioSettings | None = None,
    ) -> None:
        """Run the Voice Assistant pipeline.

        Feeds queued UDP audio into the assist pipeline, forwarding events
        via _event_callback, then blocks until TTS streaming is complete.
        Always calls handle_finished() on exit.
        """
        if audio_settings is None or audio_settings.volume_multiplier == 0:
            audio_settings = VoiceAssistantAudioSettings()
        # v2+ devices accept raw PCM over UDP; older ones play an MP3 URL.
        tts_audio_output = (
            "raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
        )
        _LOGGER.debug("Starting pipeline")
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT
        try:
            await async_pipeline_from_audio_stream(
                self.hass,
                context=self.context,
                event_callback=self._event_callback,
                stt_metadata=stt.SpeechMetadata(
                    language="",  # set in async_pipeline_from_audio_stream
                    format=stt.AudioFormats.WAV,
                    codec=stt.AudioCodecs.PCM,
                    bit_rate=stt.AudioBitRates.BITRATE_16,
                    sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                    channel=stt.AudioChannels.CHANNEL_MONO,
                ),
                stt_stream=self._iterate_packets(),
                pipeline_id=pipeline_select.get_chosen_pipeline(
                    self.hass, DOMAIN, self.device_info.mac_address
                ),
                conversation_id=conversation_id,
                device_id=device_id,
                tts_audio_output=tts_audio_output,
                start_stage=start_stage,
                wake_word_settings=WakeWordSettings(timeout=5),
                audio_settings=AudioSettings(
                    noise_suppression_level=audio_settings.noise_suppression_level,
                    auto_gain_dbfs=audio_settings.auto_gain,
                    volume_multiplier=audio_settings.volume_multiplier,
                    is_vad_enabled=bool(flags & VoiceAssistantCommandFlag.USE_VAD),
                ),
            )
            # Block until TTS is done sending
            await self._tts_done.wait()
            _LOGGER.debug("Pipeline finished")
        except PipelineNotFound:
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                {
                    "code": "pipeline not found",
                    "message": "Selected pipeline not found",
                },
            )
            _LOGGER.warning("Pipeline not found")
        except WakeWordDetectionAborted:
            pass  # Wake word detection was aborted and `handle_finished` is enough.
        except WakeWordDetectionError as e:
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                {
                    "code": e.code,
                    "message": e.message,
                },
            )
        finally:
            self.handle_finished()

    async def _send_tts(self, media_id: str) -> None:
        """Send TTS audio to device via UDP, paced to roughly real time.

        Always emits TTS_STREAM_END and sets _tts_done, even on failure.
        """
        try:
            if self.transport is None:
                return
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
            )
            _extension, audio_bytes = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )
            _LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))
            bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
            sample_offset = 0
            samples_left = len(audio_bytes) // bytes_per_sample
            while samples_left > 0:
                bytes_offset = sample_offset * bytes_per_sample
                # Use the module constant instead of the hard-coded 1024.
                chunk: bytes = audio_bytes[
                    bytes_offset : bytes_offset + UDP_MAX_PACKET_SIZE
                ]
                samples_in_chunk = len(chunk) // bytes_per_sample
                samples_left -= samples_in_chunk
                # debug, not the deprecated/noisy warn, for per-chunk logging.
                _LOGGER.debug(
                    "Sending %d of chunk to %s", len(chunk), self.remote_addr
                )
                self.transport.sendto(chunk, self.remote_addr)
                # Sleep slightly less (0.9x) than the chunk's play time so
                # the device's buffer never runs dry.
                await asyncio.sleep(
                    samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
                )
                sample_offset += samples_in_chunk
        finally:
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
            )
            self._tts_done.set()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment