@willwade
Created June 30, 2024 15:48

import logging
import math
import threading
import time
import wave
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
import sounddevice as sd

FileFormat = Union[Literal["wav"], Literal["mp3"]]


class AbstractTTS(ABC):
    def __init__(self):
        self.voice_id = None
        self.audio_rate = 22050
        self.audio_bytes = None
        self.playing = threading.Event()
        self.playing.clear()  # Not playing by default
        self.position = 0  # Position in the byte stream
        self.timings = []
        self.timers = []
        self.stream = None
        self.play_thread = None
        self.properties = {'volume': "", 'rate': "", 'pitch': ""}
        self.callbacks = {'onStart': None, 'onEnd': None, 'started-word': None}
        self.stream_lock = threading.Lock()

    @abstractmethod
    def get_voices(self) -> List[Dict[str, Any]]:
        pass

    def set_voice(self, voice_id: str, lang: str = "en-US"):
        self.voice_id = voice_id
        self.lang = lang

    @classmethod
    @abstractmethod
    def supported_formats(cls) -> List[FileFormat]:
        pass

    @abstractmethod
    def synth_to_bytes(self, text: Any, format: Optional[FileFormat] = "wav") -> bytes:
        pass

    def synth_to_file(self, text: Any, filename: str, format: Optional[FileFormat] = None) -> None:
        audio_content = self.synth_to_bytes(text, format=format or "wav")
        # Wrap the raw PCM data in a 16-bit mono WAV container.
        channels = 1
        sample_width = 2
        with wave.open(filename, "wb") as file:
            file.setnchannels(channels)
            file.setsampwidth(sample_width)
            file.setframerate(self.audio_rate)
            file.writeframes(audio_content)

    def synth(self, text: str, filename: str, format: Optional[FileFormat] = "wav"):
        self.synth_to_file(text, filename, format)

    def speak(self, text: Any, format: Optional[FileFormat] = "wav") -> None:
        try:
            audio_bytes = self.synth_to_bytes(text, format)
            audio_bytes = self.apply_fade_in(audio_bytes)
            # Convert byte data to int16 samples suitable for sounddevice
            audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
            sd.play(audio_data, samplerate=self.audio_rate)
            sd.wait()
        except Exception as e:
            logging.error(f"Error playing audio: {e}")

    def apply_fade_in(self, audio_bytes, fade_duration_ms=50, sample_rate=None):
        # Ramp the first few milliseconds up from silence to avoid a click at playback start.
        sample_rate = sample_rate or self.audio_rate
        num_fade_samples = int(fade_duration_ms * sample_rate / 1000)
        fade_in = [i / num_fade_samples for i in range(num_fade_samples)]
        audio_samples = self.bytes_to_samples(audio_bytes)
        for i in range(min(len(audio_samples), num_fade_samples)):
            audio_samples[i] = int(audio_samples[i] * fade_in[i])
        return self.samples_to_bytes(audio_samples)

    def speak_streamed(self, text: Any, format: Optional[FileFormat] = "wav"):
        try:
            audio_bytes = self.synth_to_bytes(text, format)
            if not isinstance(audio_bytes, (bytes, bytearray)):
                raise ValueError("Synthesized speech is not in bytes format")
        except Exception as e:
            logging.error(f"Error synthesizing speech: {e}")
            return
        self.audio_bytes = self.apply_fade_in(audio_bytes)
        self.position = 0
        self.playing.set()
        self._trigger_callback('onStart')
        try:
            self.play_thread = threading.Thread(target=self._start_stream)
            self.play_thread.start()
        except Exception as e:
            logging.error(f"Failed to play audio: {e}")
            raise

    def _start_stream(self):
        try:
            # RawOutputStream hands the callback a writable byte buffer, which
            # matches the byte-oriented bookkeeping done with self.position.
            self.stream = sd.RawOutputStream(
                samplerate=self.audio_rate,
                channels=1,
                dtype='int16',
                callback=self.callback
            )
            # The context manager starts the stream on entry and stops it on exit.
            with self.stream:
                while self.playing.is_set() and self.stream.active:
                    time.sleep(0.1)
        except Exception as e:
            logging.error(f"Failed to start stream: {e}")

    def callback(self, outdata, frames, time, status):
        if self.playing.is_set():
            end_position = self.position + frames * 2  # 2 bytes per int16 sample
            chunk = self.audio_bytes[self.position:end_position]
            # Pad the final chunk with silence so it fills the whole buffer.
            if len(chunk) < len(outdata):
                chunk = chunk + b"\x00" * (len(outdata) - len(chunk))
            outdata[:] = chunk
            self.position = end_position
            if self.position >= len(self.audio_bytes):
                self._trigger_callback('onEnd')
                self.playing.clear()
        else:
            outdata[:] = b"\x00" * len(outdata)

    def pause_audio(self):
        self.playing.clear()

    def resume_audio(self):
        self.playing.set()
        # Pausing lets the stream thread exit, so restart playback from the
        # saved byte position if it is no longer running.
        if self.play_thread is None or not self.play_thread.is_alive():
            self.play_thread = threading.Thread(target=self._start_stream)
            self.play_thread.start()

    def stop_audio(self):
        self.playing.clear()
        if self.play_thread and self.play_thread.is_alive():
            self.play_thread.join()
        for timer in self.timers:
            timer.cancel()
        self.timers.clear()

    def connect(self, event_name: str, callback: Callable):
        if event_name in self.callbacks:
            self.callbacks[event_name] = callback

    def _trigger_callback(self, event_name: str, *args):
        if event_name in self.callbacks and self.callbacks[event_name] is not None:
            self.callbacks[event_name](*args)

    def bytes_to_samples(self, audio_bytes):
        """Convert bytes to a list of samples (int16)."""
        return [int.from_bytes(audio_bytes[i:i + 2], 'little', signed=True)
                for i in range(0, len(audio_bytes), 2)]

    def samples_to_bytes(self, samples):
        """Convert a list of samples (int16) to bytes."""
        return b''.join(sample.to_bytes(2, 'little', signed=True) for sample in samples)


class DummyTTS(AbstractTTS):
    """Dummy TTS implementation for demonstration purposes."""

    def get_voices(self) -> List[Dict[str, Any]]:
        return [{"id": "dummy_voice", "name": "Dummy Voice", "lang": "en-US"}]

    @classmethod
    def supported_formats(cls) -> List[FileFormat]:
        return ["wav"]

    def synth_to_bytes(self, text: Any, format: Optional[FileFormat] = "wav") -> bytes:
        # Dummy implementation: create a simple waveform for demonstration
        duration = 5  # seconds
        frequency = 440.0  # A4
        audio_data = []
        for i in range(int(self.audio_rate * duration)):
            sample = int(32767 * 0.5 * (1 + math.sin(2 * math.pi * frequency * i / self.audio_rate)))
            audio_data.append(sample)
        return self.samples_to_bytes(audio_data)


# Instantiate the DummyTTS
tts = DummyTTS()
# Callback functions
def on_start():
    print("Playback started.")

def on_end():
    print("Playback ended.")

# Connect callbacks
tts.connect('onStart', on_start)
tts.connect('onEnd', on_end)

# Synthesize and play text
print("Playing synthesized speech...")
tts.speak_streamed("Hello, this is a test.")
# Pause and resume playback
time.sleep(2)
print("Pausing playback...")
tts.pause_audio()
time.sleep(2)
print("Resuming playback...")
tts.resume_audio()
# Save synthesized speech to file
print("Saving synthesized speech to file...")
tts.synth("Hello, this is a test.", "output.wav")
# Wait for playback to finish
time.sleep(10)