|
import bs4 |
|
import configparser |
|
from datasets import load_dataset |
|
import diffusers |
|
from duckduckgo_search import ddg |
|
import imgcat |
|
import json |
|
import openai |
|
import os |
|
import pvrecorder as pv |
|
import requests |
|
import select |
|
import soundfile as sf |
|
import struct |
|
import subprocess |
|
import sys |
|
import tempfile |
|
import textwrap |
|
import time |
|
import threading |
|
import torch |
|
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan |
|
from typing import Optional, List, Tuple |
|
import wave |
|
import whisper |
|
|
|
|
|
def retry(n=3):
    """Decorator factory: retry a function until it returns a non-None value.

    The wrapped function is called up to ``n`` times. The first non-None
    result is returned immediately. A ``None`` result counts as a failure; if
    attempts remain, the next call is made after a 0.5 s pause. Exceptions
    raised by the wrapped function are not caught.

    Args:
        n: Maximum number of attempts.

    Raises:
        ValueError: If every one of the ``n`` attempts returned ``None``.
    """
    import functools

    def wrapper(func):
        @functools.wraps(func)  # keep the wrapped function's name/docstring
        def inner(*args, **kwargs):
            for attempt in range(n):
                response = func(*args, **kwargs)
                if response is not None:
                    return response
                # Only pause when another attempt will follow.
                if attempt < n - 1:
                    time.sleep(.5)
            raise ValueError(f"Tried {n} times with no success")

        return inner

    return wrapper
|
|
|
|
|
class Crystal:
    """Voice-driven storytelling assistant.

    Records speech with pvrecorder and transcribes it with Whisper. The text
    goes to an OpenAI chat model, and the reply is spoken with the ``say``
    command. When the reply carries a caption, an image is rendered with
    Stable Diffusion and shown via ``imgcat``.
    """

    def __init__(self, audio_device: Optional[int] = None, model: str = "gpt-3.5-turbo"):
        """Load all models and choose the audio input device.

        Args:
            audio_device: pvrecorder device index. ``None`` prompts the user
                interactively (that prompt yields -1 when the user just
                accepts the default).
            model: OpenAI chat model name.
        """
        # Speech-to-text (Whisper).
        self.__audio_dec = whisper.load_model("base")
        self.__model = model

        # Text-to-speech (SpeechT5 + HiFi-GAN vocoder); only used by
        # speak_via_huggingface.
        self.__processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.__t2s_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        self.__vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # load xvector containing speaker's voice characteristics from a dataset
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
        self.__speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        if audio_device is None:
            self.__device_index = self._get_audio_input()
        else:
            self.__device_index = audio_device

        # Running totals of OpenAI token usage across all requests.
        self.tokens = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }

        self.__init_stable_diffusion()
        # Chat history sent with every request (system, user and assistant turns).
        self.__chat_messages = []

    def _add_tokens(self, response):
        """Accumulate token usage from *response* and record its reply message(s)."""
        self.tokens["prompt_tokens"] += response["usage"]["prompt_tokens"]
        self.tokens["completion_tokens"] += response["usage"]["completion_tokens"]
        self.tokens["total_tokens"] += response["usage"]["total_tokens"]
        messages = [m["message"].to_dict_recursive() for m in response["choices"]]
        self.__chat_messages.extend(messages)

    @staticmethod
    def _get_audio_input():
        """List the audio devices and ask the user to pick one.

        Returns:
            The integer index the user typed, or -1 when the answer is not an
            integer (e.g. pressing Enter to accept the default).
        """
        print("Here are your audio devices:")
        for index, device in enumerate(pv.PvRecorder.get_audio_devices()):
            print(f"[{index}] {device}")
        device = input("Which should we use? [default] ")
        try:
            return int(device)
        except ValueError:
            return -1

    def listen(self) -> str:
        """Record from the microphone until stdin becomes readable.

        Input is polled with ``select`` on stdin. In a line-buffered terminal
        this presumably means the user has to press Enter (confirm).

        Returns:
            Path of a temporary mono, 16-bit, 16 kHz WAV file with the
            recording. The file is not deleted automatically.
        """
        print("Press any key to stop recording: ", end="")
        try:
            recorder = pv.PvRecorder(device_index=self.__device_index, frame_length=512)
        except Exception:
            print(f"index: {self.__device_index}, devices: {pv.PvRecorder.get_audio_devices()}")
            raise

        audio = []
        try:
            recorder.start()
            # Read 512-sample frames until something is typed on stdin.
            while not self._has_any_key():
                audio.extend(recorder.read())
            # Consume the line that stopped the recording so it does not leak
            # into the next input() call.
            _ = input()
            recorder.stop()
        finally:
            # Release the device even if recording was interrupted.
            recorder.delete()

        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            with wave.open(temp_file, 'w') as f:
                # (nchannels, sampwidth, framerate, nframes, comptype, compname);
                # wave fixes up nframes on close.
                f.setparams((1, 2, 16000, 512, "NONE", "NONE"))
                # WAV PCM samples are little-endian regardless of host order.
                f.writeframes(struct.pack(f"<{len(audio)}h", *audio))

        print()
        print("audio path:", temp_file.name)
        print()
        return temp_file.name

    @staticmethod
    def _has_any_key():
        """Return a non-empty (truthy) list if stdin has data ready; never blocks."""
        return select.select([sys.stdin], [], [], 0.0)[0]

    def get_text(self, audio_file: str):
        """Transcribe *audio_file* with Whisper and return the decoded text.

        Only the first 30 seconds are used, because the audio is padded or
        trimmed to that length before decoding.
        """
        audio = whisper.load_audio(audio_file)
        audio = whisper.pad_or_trim(audio)

        # make log-Mel spectrogram and move to the same device as the model
        mel = whisper.log_mel_spectrogram(audio).to(self.__audio_dec.device)

        # NOTE(review): fp16 disabled, presumably because the model runs on a
        # device without fp16 support (e.g. CPU) — confirm.
        options = whisper.DecodingOptions(fp16=False)
        result = whisper.decode(self.__audio_dec, mel, options)
        return result.text

    def speak(self, path: str):
        """Speak the chat reply stored in *path* and act on its annotations.

        The reply is parsed with ``Actions.parse_file``:

        - A caption starts image generation in a background thread. The file
          is rewritten without the annotation lines so they are not read
          aloud.
        - Otherwise, a fact triggers a Wikipedia-backed re-prompt and the new
          answer is spoken instead. If that re-prompt fails, the original
          reply is spoken.

        Speech uses the ``say`` command; Ctrl-C skips the rest of it. The
        image thread is joined before returning.
        """
        thread = None
        action = Actions.parse_file(path)
        if action.caption is not None:
            with open(path, "w") as handle:
                print(action.modified_input, file=handle)
            thread = threading.Thread(target=self.generate_image, args=(action.caption,))
            thread.start()
        elif action.fact is not None:
            new_path = self.ddg_reprompt(action.modified_input, action.fact)
            if new_path is not None:
                path = new_path

        try:
            subprocess.run(["say", "-f", path], capture_output=True)
        except KeyboardInterrupt:
            pass

        if thread is not None:
            thread.join()

    def speak_via_huggingface(self, action: 'Actions'):
        """Speak ``action.modified_input`` with the local SpeechT5 model.

        The token ids are synthesized in chunks, including a shorter trailing
        chunk. The waveforms are concatenated into one 16 kHz WAV file, which
        is played with ``afplay``. Nothing is played for empty input.
        """
        inputs = self.__processor(text=action.modified_input, return_tensors="pt")
        input_ids = inputs["input_ids"]
        # NOTE(review): chunk length reuses the speaker-embedding width; it is
        # not tied to any model limit visible here — confirm it is intended.
        frame_size = self.__speaker_embeddings.size()[1]
        n_tokens = input_ids.size()[1]

        segments = []
        for start in range(0, n_tokens, frame_size):
            page = input_ids[:, start:start + frame_size].to(torch.int64)
            segments.append(
                self.__t2s_model.generate_speech(page, self.__speaker_embeddings, vocoder=self.__vocoder)
            )
        if not segments:
            return
        speech = torch.cat(segments)

        with tempfile.NamedTemporaryFile(suffix='.wav', mode="wb", delete=True) as handle:
            # Write a single WAV file (one header) for the whole utterance.
            sf.write(handle, speech.numpy(), samplerate=16000, format="WAV")
            # Flush so afplay sees the complete file while it is still open.
            handle.flush()
            subprocess.run(["afplay", handle.name], capture_output=True)

    def ddg_reprompt(self, orig_prompt, query, n_articles: int = 1) -> Optional[str]:
        """Re-ask *orig_prompt* with Wikipedia context about *query*.

        Args:
            orig_prompt: The question/answer text to re-ask.
            query: Search topic used to fetch context.
            n_articles: Number of Wikipedia pages to pull context from.

        Returns:
            Path of a temp file holding the new answer (also printed), or
            None if the chat request failed.
        """
        context = self.get_wikipedia_context(query, n_articles)
        # Built line by line instead of with dedent: a multi-line {context}
        # substituted into an indented template would defeat dedent.
        new_prompt = (
            f"Given this context about {query}:\n"
            "\n"
            f"{context}\n"
            "\n"
            f"Answer this and add more intereresting details, be creative: {orig_prompt}\n"
        )
        print()
        print()
        print("New Prompt:")
        path = self.__send_chatgpt_msg(new_prompt)
        if path is None:
            return None
        with open(path) as handle:
            print(handle.read())
        return path

    @staticmethod
    def get_wikipedia_context(query: str, n_articles: int = 1) -> str:
        """Fetch the main text of the top Wikipedia hits for *query*.

        Searches DuckDuckGo restricted to wikipedia.org and downloads the
        first ``n_articles`` result pages. Each page is reduced to its
        ``div.mw-parser-output`` text: scripts/styles and everything from the
        "See also" / "References" headings onward are removed.

        Returns:
            The extracted article texts joined by newlines, or "" when nothing
            usable was found.
        """

        def remove_trailing_sections(soup, heading: str):
            # Find the h2 whose span text is *heading* and drop it plus all
            # following siblings.
            target_span = soup.find("span", string=heading)
            if target_span is None:
                return
            target_h2 = target_span.find_parent("h2")
            if target_h2 is None:
                return
            for sibling in target_h2.find_next_siblings():
                sibling.extract()
            target_h2.extract()

        query = f"{query} site:wikipedia.org"
        hits = ddg(query) or []
        result = []
        for href in [r["href"] for r in hits][:n_articles]:
            print("Get", href)
            html = requests.get(href, timeout=30).content
            soup = bs4.BeautifulSoup(html, "html.parser")

            # Remove non-visible elements (e.g. script, style, etc.)
            for element in soup(["script", "style", "head", "title", "meta", "[document]"]):
                element.extract()

            # Remove sections that seem to be "after the main content"
            remove_trailing_sections(soup, "See also")
            remove_trailing_sections(soup, "References")

            # Select the main part of a wikipedia page; skip pages without it.
            main_text = soup.find("div", {"class": "mw-parser-output"})
            if main_text is None:
                continue
            result.append(main_text.getText().strip())
        return "\n".join(result)

    def __send_chatgpt_msg(self, prompt: str, role: str = "user") -> Optional[str]:
        """Send *prompt* as a chat message and save the reply to a temp file.

        System messages are only recorded in the history (no request is made).
        For other roles, the history plus the new message is sent. The message
        is committed to the history only if the request succeeds, so retried
        calls do not accumulate duplicates.

        Returns:
            Path of a temp .txt file containing the reply, "/dev/null" for
            system messages, or None if the request raised (the traceback is
            printed).
        """
        msg = {"role": role, "content": prompt}
        if role == "system":
            self.__chat_messages.append(msg)
            return "/dev/null"
        try:
            completion = openai.ChatCompletion.create(
                model=self.__model,
                messages=self.__chat_messages + [msg],
            )
        except Exception:
            import traceback
            traceback.print_exc()
            return None
        self.__chat_messages.append(msg)
        # Adds the assistant reply to the history and updates token counts.
        self._add_tokens(completion)
        with tempfile.NamedTemporaryFile(suffix='.txt', mode="w", delete=False) as temp_file:
            print(completion.choices[0].message.content, file=temp_file)
        return temp_file.name

    @retry(3)
    def init_chat(self) -> Optional[str]:
        """Install the StoryAI system prompt; returns "/dev/null" (no reply)."""
        prompt = textwrap.dedent("""\
            You are StoryAI, an AI storyteller that answers every question a kid may have. You are
            a 42 year old horse trainer from Vermont. Think step by step.

            If your response is a story of any kind, the response should be several
            paragraphs and include a caption that describes part of the story. Be creative.

            Example:
            Tell a story about a horse that leads an army
            Caption: Horse running across sunny field. Photorealistic. Artsy. HDR.
            A horse named Jager loved the open air. Every day...

            Example:
            Tell about Elsa and Ana fighting off hordes of soldiers
            Caption: Elsa and Ana hiding in a corner. Animated. High definition. Rich colors.
            Once upon a time in a faraway land...

            If the prompt is asking for a picture, respond with a caption as well as also describing the scene.
            Be creative. A picture will be created for them. Think step by step.

            If the prompt is asking for information, include a summarizing fact of the question they
            are looking for. The response should be short and succinct.

            Example:
            What is the tallest building in the world?
            Fact: The tallest building in the world, very large buildings
            The tallest building in the world is the Burj Khalifa

            Example:
            Why did the flower in the story die?
            Fact: Causes of death in flowers, dehydration, sunlight
            The flower in the story died because it did not get enough water

            Include facts and captions when possible.
            """)
        return self.__send_chatgpt_msg(prompt, role="system")

    @retry(3)
    def send_chat(self, prompt: str) -> Optional[str]:
        """Send a user prompt; returns the reply file path (retried up to 3 times)."""
        return self.__send_chatgpt_msg(prompt)

    def __init_stable_diffusion(self):
        """Load Stable Diffusion v1.5 onto the best available torch device."""
        pipe = diffusers.DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        # Prefer Apple's Metal backend (the original target), then CUDA, then CPU.
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        pipe = pipe.to(device)

        # Recommended if your computer has < 64 GB of RAM
        pipe.enable_attention_slicing()
        self.__stable_diffusion = pipe

    def generate_image(self, prompt: str):
        """Render *prompt* with Stable Diffusion and display it with imgcat.

        The PNG is saved to a temporary file that is kept (delete=False).
        """
        image = self.__stable_diffusion(prompt).images[0]
        with tempfile.NamedTemporaryFile(suffix='.png', mode="wb", delete=False) as temp_file:
            image.save(temp_file, format="PNG")
            image_file = temp_file.name
        with open(image_file, "rb") as handle:
            imgcat.imgcat(handle)
|
|
|
|
|
class Actions:
    """Annotations ("Caption:" / "Fact:" lines) extracted from a chat reply.

    Attributes:
        modified_input: The reply text with all annotation lines removed.
        caption: Text after the first "Caption:" label (label stripped), or None.
        fact: Text after the first "Fact:" label (label stripped), or None.
    """

    # (lower-case line label, action name); labels match case-insensitively.
    _LABELS = (("caption:", "caption"), ("fact:", "fact"))

    def __init__(self, actions, modified_input):
        """Store parsed annotations.

        Args:
            actions: Sequence of ``(action_name, text)`` pairs; the first pair
                of each kind wins.
            modified_input: Reply text without the annotation lines.
        """
        self.__actions = actions
        self.modified_input = modified_input
        self.caption = next((text for action, text in self.__actions if action == "caption"), None)
        self.fact = next((text for action, text in self.__actions if action == "fact"), None)

    @classmethod
    def parse_file(cls, path: str) -> 'Actions':
        """Parse the reply stored at *path* into an ``Actions`` instance.

        A line is an annotation if, ignoring surrounding whitespace, it starts
        with "Caption:" or "Fact:" (any case). Its text after the label is
        recorded. All other lines are kept, unchanged, in ``modified_input``.
        """
        with open(path) as handle:
            input_string = handle.read()

        occurrences = []
        non_occurrences = []
        for line in input_string.splitlines():
            stripped = line.strip()
            for label, action in cls._LABELS:
                if stripped[:len(label)].lower() == label:
                    # Drop the label so only the payload reaches the image
                    # prompt / search query.
                    occurrences.append((action, stripped[len(label):].strip()))
                    break
            else:
                non_occurrences.append(line)

        return cls(occurrences, '\n'.join(non_occurrences))
|
|
|
|
|
def setup_openai_key(cfg_path="~/.keys"):
    """Configure the OpenAI client from an INI-style key file.

    Reads ``openai-key`` and ``openai-org`` from the DEFAULT section of
    *cfg_path* (``~`` is expanded). They are assigned to ``openai.api_key``
    and ``openai.organization`` respectively. A missing entry surfaces as the
    ``configparser`` error raised by ``get``.
    """
    parser = configparser.ConfigParser()
    full_path = os.path.expanduser(cfg_path)
    parser.read(full_path)

    section = "DEFAULT"
    openai.api_key = parser.get(section, "openai-key")
    openai.organization = parser.get(section, "openai-org")
|
|
|
|
|
def print_file(file):
    """Print a file's path followed by its contents; do nothing for None."""
    if file is not None:
        print("text file:", file)
        with open(file) as fh:
            contents = fh.read()
        print(contents)
|
|
|
def main():
    """Run the interactive record -> transcribe -> chat -> speak loop.

    Setup: loads the OpenAI credentials, builds a ``Crystal`` on audio
    device -1 and installs the system prompt.

    Each iteration waits for Enter, records and transcribes speech, and sends
    the text to the chat model. It then prints token usage and the reply, and
    speaks the reply. Blank transcriptions are skipped. A chat request that
    still fails after all retries is reported without ending the session.
    Ctrl-C or end-of-input stops the loop.
    """
    setup_openai_key()
    print("starting")
    app = Crystal(-1)
    print_file(app.init_chat())
    try:
        while True:
            _ = input("Try again! <hit enter>")
            text = app.get_text(app.listen())
            print(text)
            if not text.strip():
                print("(heard nothing, skipping)")
                continue
            try:
                response = app.send_chat(text)
            except ValueError as err:
                # retry() raises ValueError once every attempt has failed.
                print(f"Chat request failed: {err}")
                continue
            print(app.tokens)
            print_file(response)
            app.speak(response)
    except (KeyboardInterrupt, EOFError):
        print()


if __name__ == "__main__":
    main()