text_to_speech_with_haystack_pipeline.py

import logging
from typing import TYPE_CHECKING, Union, List, Optional, Dict, Any, Tuple

import os
import hashlib
from pathlib import Path
from dataclasses import asdict

if not TYPE_CHECKING:
    from pydantic.dataclasses import dataclass
else:
    from dataclasses import dataclass  # type: ignore # pylint: disable=ungrouped-imports

from haystack import BaseComponent, Answer
from espnet2.bin.tts_inference import Text2Speech
import soundfile as sf


@dataclass
class AudioAnswer(Answer):
    answer: Path
    context: Optional[Path] = None
    offsets_in_document: Optional[Any] = None
    offsets_in_context: Optional[Any] = None

    def __str__(self):
        return f"<AudioAnswer: answer='{self.answer}', score={self.score}, context='{self.context}'>"

    def __repr__(self):
        return f"<AudioAnswer {asdict(self)}>"


@dataclass
class GeneratedAudioAnswer(AudioAnswer):
    type: str = "text-to-speech"
    answer_transcript: Optional[str] = None
    context_transcript: Optional[str] = None

    @classmethod
    def from_text_answer(
        cls,
        answer_object: Answer,
        generated_audio_answer: Any,
        generated_audio_context: Optional[Any] = None,
    ):
        answer_dict = answer_object.to_dict()
        # Drop empty fields so that the dataclass defaults apply instead.
        answer_dict = {key: value for key, value in answer_dict.items() if value}
        # Use .get() here: empty fields (e.g. a missing context) were dropped above.
        answer_dict["answer_transcript"] = answer_dict.get("answer")
        answer_dict["context_transcript"] = answer_dict.get("context")
        answer_dict["answer"] = generated_audio_answer
        answer_dict["context"] = generated_audio_context
        return cls(**answer_dict)
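
# A rough sketch of what from_text_answer produces (all values below are made up):
#
#   text_answer  = Answer(answer="Amsterdam", context="Amsterdam has 872,680 inhabitants.")
#   audio_answer = GeneratedAudioAnswer.from_text_answer(
#       answer_object=text_answer,
#       generated_audio_answer=Path("generated_audio_answers/<md5-of-answer>.wav"),
#       generated_audio_context=Path("generated_audio_answers/<md5-of-context>.wav"),
#   )
#   audio_answer.answer             -> the Path to the answer's .wav file
#   audio_answer.answer_transcript  -> "Amsterdam" (the original text answer)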


class TextToSpeech(BaseComponent):
    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: Union[str, Path] = "espnet/kan-bayashi_ljspeech_vits",
        generated_audio_path: Path = Path(__file__).parent / "generated_audio_answers",
    ):
        super().__init__()
        self.model = Text2Speech.from_pretrained(model_name_or_path)
        self.generated_audio_path = generated_audio_path
        if not os.path.exists(self.generated_audio_path):
            os.mkdir(self.generated_audio_path)

    def text_to_speech(self, text: str) -> Any:
        filename = hashlib.md5(text.encode("utf-8")).hexdigest()
        path = self.generated_audio_path / f"{filename}.wav"
        # Duplicate answers might be in the list; in that case we save time by not regenerating the audio.
        if not os.path.exists(path):
            output = self.model(text)["wav"]
            sf.write(path, output.numpy(), self.model.fs, "PCM_16")
        return path
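
    # Note: hashing the input text yields a stable, filesystem-safe filename
    # (a 32-character hex digest), so identical strings map to the same .wav
    # file and the synthesis above runs only once per unique text.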

    def run(self, answers: List[Answer]) -> Tuple[Dict[str, List[AudioAnswer]], str]:
        audio_answers = []
        for answer in answers:
            logging.info(f"Processing answer '{answer.answer}' and its context...")
            answer_audio = self.text_to_speech(answer.answer)
            context_audio = self.text_to_speech(answer.context)
            audio_answer = GeneratedAudioAnswer.from_text_answer(
                answer_object=answer,
                generated_audio_answer=answer_audio,
                generated_audio_context=context_audio,
            )
            audio_answer.type = "generative"
            audio_answers.append(audio_answer)
        return {"answers": audio_answers}, "output_1"

    def run_batch(self, answers: List[Answer]) -> Tuple[Dict[str, List[AudioAnswer]], str]:
        return self.run(answers)

node = TextToSpeech()
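
# A quick standalone check of the node, outside any pipeline (a sketch: the Answer
# below is a made-up example, and the first call downloads the ESPnet model):
#
#   test_answer = Answer(answer="Amsterdam", context="Amsterdam has 872,680 inhabitants.")
#   output, edge = node.run(answers=[test_answer])
#   print(output["answers"][0])  # <AudioAnswer: answer='generated_audio_answers/<md5>.wav', ...>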

from haystack.document_stores import ElasticsearchDocumentStore
from haystack.pipelines import Pipeline
from haystack.nodes import BM25Retriever, FARMReader

document_store = ElasticsearchDocumentStore(
    host="localhost",
    username="",
    password="",
    index="document",
)

retriever = BM25Retriever(document_store=document_store)
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
pipeline = Pipeline()
pipeline.add_node(component=retriever, name="retriever", inputs=['Query'])
pipeline.add_node(component=reader, name="reader", inputs=['retriever'])
pipeline.add_node(component=node, name="text2speech", inputs=['reader'])
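
# The resulting query flow:
#   Query -> BM25Retriever (top-k documents) -> FARMReader (top-k text answers) -> TextToSpeech (audio answers)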


def ask_a_question(question, answers_count=5, retrieved_docs_count=10):
    results = pipeline.run(
        query=question,
        params={"retriever": {"top_k": retrieved_docs_count}, "reader": {"top_k": answers_count}},
    )
    # This part simply prints the answers in a nice way.
    print()
    print()
    print(f"Query: {question}\n")
    for answer in results["answers"]:
        print(f"""
 - {answer.answer_transcript}: {answer.answer} (score: {answer.score:.2f})
   From: {answer.meta['name']}
   Context: {answer.context}\n""")

# Perfect outcome: the right answer is in the first position.
ask_a_question("How many people live in Amsterdam?")  # 872,680: https://en.wikipedia.org/wiki/Amsterdam

# Decent outcome: the right answer is not first, but appears in the top 5.
ask_a_question("Which river crosses Rome?")  # Tiber: https://en.wikipedia.org/wiki/Rome#Location

# Bad outcome: the right answer is not found, although it's present in the dataset.
ask_a_question("What's the tallest mountain in the world?")  # Mount Everest: https://en.wikipedia.org/wiki/Nepal

# A more difficult question:
ask_a_question("How many people live in the capital of France?")  # ~2.1 million: https://en.wikipedia.org/wiki/Paris