Created
June 30, 2024 15:48
-
-
Save willwade/316279b41856a3e101d27e7806b77132 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import math
import threading
import time
import wave
from abc import ABC, abstractmethod
from typing import Any, List, Literal, Optional, Union, Dict, Callable

import sounddevice as sd
# Output formats a TTS engine can be asked to synthesize.
FileFormat = Union[Literal["wav"], Literal["mp3"]]


class AbstractTTS(ABC):
    """Base class for text-to-speech engines with simple audio playback.

    Subclasses implement :meth:`get_voices`, :meth:`supported_formats` and
    :meth:`synth_to_bytes`. This class adds file export, blocking playback
    (:meth:`speak`) and pausable background playback
    (:meth:`speak_streamed`).

    Every playback helper treats the synthesized bytes as little-endian
    signed 16-bit mono samples played at ``self.audio_rate`` Hz.
    """

    def __init__(self):
        self.voice_id = None
        self.lang = None
        self.audio_rate = 22050
        self.audio_bytes = None
        # Set while audio is actually being output; cleared while paused,
        # stopped, or after the last chunk has been delivered.
        self.playing = threading.Event()
        self.position = 0  # Byte offset of the next chunk in audio_bytes
        self.timings = []
        self.timers = []
        self.properties = {'volume': "", 'rate': "", 'pitch': ""}
        self.callbacks = {'onStart': None, 'onEnd': None, 'started-word': None}
        self.stream_lock = threading.Lock()
        # Background thread that owns the output stream (None until the
        # first streamed playback).
        self.play_thread = None
        # Tells the stream thread to close its stream. Kept separate from
        # `playing` so that pausing does not tear the stream down.
        self._stop_requested = threading.Event()

    @abstractmethod
    def get_voices(self) -> List[Dict[str, Any]]:
        """Return a description of every voice the engine offers."""
        pass

    def set_voice(self, voice_id: str, lang: str = "en-US"):
        """Select the voice (and its language tag) used for synthesis."""
        self.voice_id = voice_id
        self.lang = lang

    @classmethod
    @abstractmethod
    def supported_formats(cls) -> List[FileFormat]:
        """Return the output formats this engine can produce."""
        pass

    @abstractmethod
    def synth_to_bytes(self, text: Any, format: Optional[FileFormat] = "wav") -> bytes:
        """Synthesize ``text`` and return the resulting audio as bytes."""
        pass

    def synth_to_file(self, text: Any, filename: str, format: Optional[FileFormat] = None) -> None:
        """Synthesize ``text`` and write it to ``filename`` as a WAV file.

        The audio is written as mono, 16-bit, at ``self.audio_rate`` Hz.
        ``format`` defaults to ``"wav"`` when omitted.

        NOTE(review): the bytes are always wrapped in a WAV container, even
        when another format is requested -- confirm that engines return raw
        PCM for every format before relying on non-"wav" output.
        """
        audio_content = self.synth_to_bytes(text, format=format or "wav")
        with wave.open(filename, "wb") as file:
            file.setnchannels(1)
            file.setsampwidth(2)
            file.setframerate(self.audio_rate)
            file.writeframes(audio_content)

    def synth(self, text: str, filename: str, format: Optional[FileFormat] = "wav"):
        """Convenience alias for :meth:`synth_to_file`."""
        self.synth_to_file(text, filename, format)

    def speak(self, text: Any, format: Optional[FileFormat] = "wav") -> Optional[bytes]:
        """Synthesize ``text`` and play it, blocking until playback is done.

        Returns the (faded-in) audio bytes that were played, or ``None`` if
        synthesis or playback failed; failures are logged, not raised.
        """
        try:
            audio_bytes = self.apply_fade_in(self.synth_to_bytes(text, format))
            # A raw stream takes the int16 byte string directly, so no
            # NumPy conversion is needed. write() blocks until the data has
            # been consumed and leaving the `with` block stops the stream.
            with sd.RawOutputStream(
                samplerate=self.audio_rate,
                channels=1,
                dtype='int16',
            ) as stream:
                stream.write(audio_bytes)
        except Exception as e:
            logging.error(f"Error playing audio: {e}")
            return None
        return audio_bytes

    def apply_fade_in(self, audio_bytes, fade_duration_ms=50, sample_rate=None):
        """Return ``audio_bytes`` with a linear fade-in applied to its start.

        Args:
            audio_bytes: 16-bit little-endian PCM data.
            fade_duration_ms: Length of the fade in milliseconds.
            sample_rate: Sample rate used to convert the duration into a
                sample count; defaults to ``self.audio_rate``.
        """
        if sample_rate is None:
            sample_rate = self.audio_rate
        num_fade_samples = int(fade_duration_ms * sample_rate / 1000)
        if num_fade_samples <= 0:
            return bytes(audio_bytes)
        # Only the leading samples change, so only they are decoded and
        # re-encoded; the remainder is copied through untouched.
        fade_len = min(num_fade_samples, len(audio_bytes) // 2)
        head = self.bytes_to_samples(audio_bytes[:fade_len * 2])
        faded = [
            int(sample * (index / num_fade_samples))
            for index, sample in enumerate(head)
        ]
        return self.samples_to_bytes(faded) + bytes(audio_bytes[fade_len * 2:])

    def speak_streamed(self, text: Any, format: Optional[FileFormat] = "wav"):
        """Synthesize ``text`` and start playing it on a background thread.

        Fires the ``'onStart'`` callback before the stream thread starts.
        Synthesis errors are logged and abort the call; a failure to start
        the thread is logged and re-raised.
        """
        try:
            audio_bytes = self.synth_to_bytes(text, format)
            if not isinstance(audio_bytes, (bytes, bytearray)):
                raise ValueError("Synthesized speech is not in bytes format")
        except Exception as e:
            logging.error(f"Error synthesizing speech: {e}")
            return
        # Two streams must never share `position`; end any earlier playback
        # before swapping in the new audio.
        self._halt_stream()
        self.audio_bytes = self.apply_fade_in(audio_bytes)
        self.position = 0
        self.playing.set()
        self._trigger_callback('onStart')
        self._launch_play_thread()

    def _launch_play_thread(self):
        """Start the background thread that runs the output stream."""
        self._stop_requested.clear()
        try:
            self.play_thread = threading.Thread(target=self._start_stream)
            self.play_thread.start()
        except Exception as e:
            logging.error(f"Failed to play audio: {e}")
            raise

    def _halt_stream(self):
        """Stop output and wait for the stream thread (if any) to exit."""
        self.playing.clear()
        self._stop_requested.set()
        thread = self.play_thread
        # Joining the current thread would raise, so skip it when this is
        # called from the stream thread itself.
        if (thread is not None and thread.is_alive()
                and thread is not threading.current_thread()):
            thread.join()

    def _start_stream(self):
        """Thread body: keep an output stream open until asked to stop."""
        try:
            stream = sd.RawOutputStream(
                samplerate=self.audio_rate,
                channels=1,
                dtype='int16',
                callback=self.callback
            )
            # Entering the context starts the stream and leaving it stops
            # and closes it, so no explicit start()/stop() is needed.
            with stream:
                while stream.active and not self._stop_requested.wait(0.1):
                    pass
        except Exception as e:
            logging.error(f"Failed to start stream: {e}")

    def callback(self, outdata, frames, time, status):
        """Fill ``outdata`` with the next audio chunk (stream callback).

        ``outdata`` is a writable byte buffer. While paused or stopped it is
        filled with silence and the read position is left untouched. When
        the final chunk has been copied, ``playing`` is cleared, the stream
        thread is asked to finish and ``'onEnd'`` fires.
        """
        if status:
            logging.warning(f"Audio stream status: {status}")
        wanted = len(outdata)
        audio = self.audio_bytes
        active = self.playing.is_set() and audio is not None
        chunk = b''
        if active:
            chunk = audio[self.position:self.position + wanted]
            self.position += len(chunk)
        # The buffer must always be written in full: copy what is available
        # and zero-pad the rest (a short final chunk, or pure silence).
        outdata[:len(chunk)] = chunk
        outdata[len(chunk):] = b'\x00' * (wanted - len(chunk))
        if active and self.position >= len(audio):
            self.playing.clear()
            self._stop_requested.set()
            self._trigger_callback('onEnd')

    def pause_audio(self):
        """Pause output; the stream stays open and plays silence."""
        self.playing.clear()

    def resume_audio(self):
        """Continue playback from the current position.

        Does nothing when no audio is loaded or it has been played to the
        end. If the stream thread is no longer running, a new one is
        started.
        """
        if self.audio_bytes is None or self.position >= len(self.audio_bytes):
            return
        self.playing.set()
        if self.play_thread is None or not self.play_thread.is_alive():
            self._launch_play_thread()

    def stop_audio(self):
        """Stop playback, wait for the stream thread and cancel all timers."""
        self._halt_stream()
        for timer in self.timers:
            timer.cancel()
        self.timers.clear()

    def connect(self, event_name: str, callback: Callable):
        """Register ``callback`` for a known event; other names are ignored."""
        if event_name in self.callbacks:
            self.callbacks[event_name] = callback

    def _trigger_callback(self, event_name: str, *args):
        """Invoke the handler registered for ``event_name``, if there is one."""
        handler = self.callbacks.get(event_name)
        if handler is not None:
            handler(*args)

    def bytes_to_samples(self, audio_bytes):
        """Convert bytes to a list of samples (int16)."""
        return [
            int.from_bytes(audio_bytes[i:i + 2], 'little', signed=True)
            for i in range(0, len(audio_bytes), 2)
        ]

    def samples_to_bytes(self, samples):
        """Convert a list of samples (int16) to bytes."""
        return b''.join(sample.to_bytes(2, 'little', signed=True) for sample in samples)
class DummyTTS(AbstractTTS):
    """Dummy TTS implementation for demonstration purposes.

    It ignores the text it is given and always produces the same fixed
    tone, which is enough to exercise the playback and file-export code.
    """

    def get_voices(self) -> List[Dict[str, Any]]:
        """Return the single placeholder voice."""
        return [{"id": "dummy_voice", "name": "Dummy Voice", "lang": "en-US"}]

    @classmethod
    def supported_formats(cls) -> List[FileFormat]:
        """Return the only format this engine produces."""
        return ["wav"]

    def synth_to_bytes(self, text: Any, format: Optional[FileFormat] = "wav") -> bytes:
        """Return five seconds of a 440 Hz tone as 16-bit PCM bytes.

        ``text`` and ``format`` are accepted for interface compatibility
        and are not used.
        """
        duration = 5  # seconds
        frequency = 440.0  # A4
        sample_count = int(self.audio_rate * duration)
        # The sine is shifted into 0..32767, so every sample is non-negative
        # and fits in a signed 16-bit value.
        audio_data = [
            int(32767 * 0.5 * (1 + math.sin(2 * math.pi * frequency * i / self.audio_rate)))
            for i in range(sample_count)
        ]
        return self.samples_to_bytes(audio_data)
# --- Demo: stream a tone, pause and resume it, then export it to a file ---


def on_end():
    """Print a notice when the 'onEnd' event fires."""
    print("Playback ended.")


def on_start():
    """Print a notice when the 'onStart' event fires."""
    print("Playback started.")


# Engine under test, with both notification handlers attached.
tts = DummyTTS()
tts.connect('onEnd', on_end)
tts.connect('onStart', on_start)

# Start background playback of the synthesized audio.
print("Playing synthesized speech...")
tts.speak_streamed("Hello, this is a test.")

# Let it play for two seconds, then pause.
time.sleep(2)
print("Pausing playback...")
tts.pause_audio()

# Stay paused for two seconds, then resume.
time.sleep(2)
print("Resuming playback...")
tts.resume_audio()

# Export the same utterance to disk while playback continues.
print("Saving synthesized speech to file...")
tts.synth("Hello, this is a test.", "output.wav")

# Give the background playback time to finish before the script exits.
time.sleep(10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment