# text_to_speech_with_haystack_pipeline.py
import logging
import os
import hashlib
from pathlib import Path
from dataclasses import asdict
from typing import TYPE_CHECKING, Union, List, Optional, Dict, Any, Tuple

# Use pydantic's validating dataclasses at runtime, but the plain stdlib ones for type checking.
if not TYPE_CHECKING:
    from pydantic.dataclasses import dataclass
else:
    from dataclasses import dataclass  # type: ignore # pylint: disable=ungrouped-imports

from haystack import BaseComponent, Answer
from espnet2.bin.tts_inference import Text2Speech
import soundfile as sf
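
# The imports above assume roughly these third-party packages (PyPI names; exact
# pins untested): farm-haystack (Haystack 1.x), espnet, espnet_model_zoo (needed
# by Text2Speech.from_pretrained to resolve model tags), and soundfile.
# One possible install line:
#
#   pip install farm-haystack espnet espnet_model_zoo soundfile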
@dataclass
class AudioAnswer(Answer):
    """An Answer whose `answer` and `context` point to audio files instead of text."""
    answer: Path
    context: Optional[Path] = None
    offsets_in_document: Optional[Any] = None
    offsets_in_context: Optional[Any] = None

    def __str__(self):
        return f"<AudioAnswer: answer='{self.answer}', score={self.score}, context='{self.context}'>"

    def __repr__(self):
        return f"<AudioAnswer {asdict(self)}>"


@dataclass
class GeneratedAudioAnswer(AudioAnswer):
    type: str = "text-to-speech"
    answer_transcript: Optional[str] = None
    context_transcript: Optional[str] = None

    @classmethod
    def from_text_answer(cls, answer_object: Answer, generated_audio_answer: Any, generated_audio_context: Optional[Any] = None):
        """Build an audio answer from a text Answer: the original text moves into
        the `*_transcript` fields, and the audio paths take its place."""
        answer_dict = answer_object.to_dict()
        # Drop empty fields so that the dataclass defaults apply instead.
        answer_dict = {key: value for key, value in answer_dict.items() if value}
        answer_dict["answer_transcript"] = answer_dict.get("answer")
        answer_dict["context_transcript"] = answer_dict.get("context")  # .get(): context may have been dropped above
        answer_dict["answer"] = generated_audio_answer
        answer_dict["context"] = generated_audio_context
        return cls(**answer_dict)
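
# A quick sanity check of the conversion above, kept as a comment since it is
# not part of the pipeline. All values here are made up for illustration:
#
#   text_answer = Answer(answer="Tiber", context="The Tiber crosses Rome.", score=0.9)
#   audio_answer = GeneratedAudioAnswer.from_text_answer(
#       answer_object=text_answer,
#       generated_audio_answer=Path("tiber.wav"),
#   )
#   assert audio_answer.answer_transcript == "Tiber"
#   assert audio_answer.answer == Path("tiber.wav")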
class TextToSpeech(BaseComponent):
    """Haystack node that turns each incoming text answer (and its context) into a .wav file."""

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: Union[str, Path] = "espnet/kan-bayashi_ljspeech_vits",
        generated_audio_path: Path = Path(__file__).parent / "generated_audio_answers",
    ):
        super().__init__()
        self.model = Text2Speech.from_pretrained(model_name_or_path)
        self.generated_audio_path = generated_audio_path
        if not os.path.exists(self.generated_audio_path):
            os.mkdir(self.generated_audio_path)

    def text_to_speech(self, text: str) -> Path:
        # Name the file after the MD5 of the text: duplicate answers might be in
        # the list, and in that case we save time by not regenerating the audio.
        filename = hashlib.md5(text.encode("utf-8")).hexdigest()
        path = self.generated_audio_path / f"{filename}.wav"
        if not os.path.exists(path):
            output = self.model(text)["wav"]
            sf.write(path, output.numpy(), self.model.fs, "PCM_16")
        return path

    def run(self, answers: List[Answer]) -> Tuple[Dict[str, List[AudioAnswer]], str]:
        audio_answers = []
        for answer in answers:
            logging.info("Processing answer '%s' and its context...", answer.answer)
            answer_audio = self.text_to_speech(answer.answer)
            context_audio = self.text_to_speech(answer.context)
            audio_answer = GeneratedAudioAnswer.from_text_answer(
                answer_object=answer,
                generated_audio_answer=answer_audio,
                generated_audio_context=context_audio,
            )
            audio_answer.type = "generative"
            audio_answers.append(audio_answer)
        return {"answers": audio_answers}, "output_1"

    def run_batch(self, answers: List[Answer]) -> Tuple[Dict[str, List[AudioAnswer]], str]:
        return self.run(answers)
node = TextToSpeech()

from haystack.document_stores import ElasticsearchDocumentStore
from haystack.pipelines import Pipeline
from haystack.nodes import BM25Retriever, FARMReader

document_store = ElasticsearchDocumentStore(
    host="localhost",
    username="",
    password="",
    index="document",
)
retriever = BM25Retriever(document_store=document_store)
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

pipeline = Pipeline()
pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="reader", inputs=["retriever"])
pipeline.add_node(component=node, name="text2speech", inputs=["reader"])
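
# The pipeline above assumes an Elasticsearch instance on localhost whose
# "document" index is already populated. If yours is empty, a minimal sketch
# for indexing a couple of documents (contents here are made up):
#
#   document_store.write_documents([
#       {"content": "Amsterdam has a population of 872,680.", "meta": {"name": "Amsterdam"}},
#       {"content": "The Tiber crosses Rome.", "meta": {"name": "Rome"}},
#   ])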
def ask_a_question(question, answers_count=5, retrieved_docs_count=10):
    results = pipeline.run(
        query=question,
        params={"retriever": {"top_k": retrieved_docs_count}, "reader": {"top_k": answers_count}},
    )
    # Print each answer's transcript, audio path, score and source in a readable way.
    print(f"\n\nQuery: {question}\n")
    for answer in results["answers"]:
        print(f"""
    - {answer.answer_transcript}: {answer.answer} (score: {answer.score:.2f})
      From: {answer.meta['name']}
      Context: {answer.context}\n""")


# Perfect outcome: the right answer is in the first position.
ask_a_question("How many people live in Amsterdam?")  # 872,680: https://en.wikipedia.org/wiki/Amsterdam

# Decent outcome: the right answer is not first, but appears in the top 5.
ask_a_question("Which river crosses Rome?")  # Tiber: https://en.wikipedia.org/wiki/Rome#Location

# Bad outcome: the right answer is not found, although it's present in the dataset.
ask_a_question("What's the tallest mountain in the world?")  # Mount Everest: https://en.wikipedia.org/wiki/Nepal

# A more difficult question: the model must first resolve "the capital of France" to Paris.
ask_a_question("How many people live in the capital of France?")  # Expects Paris's population: https://en.wikipedia.org/wiki/Paris