Skip to content

Instantly share code, notes, and snippets.

@tam17aki
Created October 10, 2023 00:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tam17aki/e7ee04fe1b80872fe283dd3970ef615c to your computer and use it in GitHub Desktop.
Save tam17aki/e7ee04fe1b80872fe283dd3970ef615c to your computer and use it in GitHub Desktop.
# Copyright (c) 2023 Blake Mallory
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import platform
import queue
import tempfile
import threading
import time
import numpy as np
import pynput.keyboard
import speech_recognition as sr
import torch
import whisper
from whisper_mic.utils import get_logger
class WhisperMic:
    """Microphone-driven speech-to-text built on a Whisper model.

    Constructing an instance loads the Whisper model and starts a background
    listener that pushes every recorded phrase into ``audio_queue``.  The
    ``transcribe*`` methods turn queued audio into text and place it in
    ``result_queue``; ``listen`` and ``listen_loop`` consume those results.
    """

    # Seconds to block on the audio queue between checks of ``break_threads``.
    _POLL_INTERVAL = 0.05

    def __init__(
        self,
        model="base",
        device=("cuda" if torch.cuda.is_available() else "cpu"),
        english=False,
        verbose=False,
        energy=300,
        pause=0.8,
        dynamic_energy=False,
        save_file=False,
        model_root="~/.cache/whisper",
        mic_index=None,
    ):
        """Load the model and start listening on the microphone.

        Args:
            model: Whisper model name passed to ``whisper.load_model``.
            device: Torch device (name) the model is moved to.
            english: Use the English-only ``.en`` variant of ``model`` when
                one exists, and transcribe with ``language="english"``.
            verbose: Queue the full Whisper result dict instead of only its
                ``"text"`` entry.
            energy: Value assigned to the recognizer's ``energy_threshold``.
            pause: Value assigned to the recognizer's ``pause_threshold``.
            dynamic_energy: Value assigned to the recognizer's
                ``dynamic_energy_threshold``.
            save_file: When true a temporary directory is created and stored
                in ``temp_dir``; nothing in this class writes to it.
            model_root: Directory passed as ``download_root`` to
                ``whisper.load_model``.
            mic_index: ``device_index`` passed to ``sr.Microphone``.
        """
        self.logger = get_logger("whisper_mic", "info")
        self.energy = energy
        self.pause = pause
        self.dynamic_energy = dynamic_energy
        self.save_file = save_file
        self.verbose = verbose
        self.english = english
        self.keyboard = pynput.keyboard.Controller()
        self.platform = platform.system()

        # platform.system() reports macOS as "Darwin" (capitalised), so the
        # comparison has to be case-insensitive for the warning to ever fire.
        if self.platform.lower() == "darwin" and device == "mps":
            self.logger.warning(
                "Using MPS for Mac, this does not work but may in the future"
            )
        device = torch.device(device)

        # Only switch to the English-only checkpoint when Whisper actually
        # ships one under that name (the "large*" models have none, and a
        # name that already ends in ".en" must not get a second suffix).
        if self.english and model + ".en" in whisper.available_models():
            model = model + ".en"

        self.audio_model = whisper.load_model(model, download_root=model_root).to(
            device
        )
        self.temp_dir = tempfile.mkdtemp() if save_file else None

        self.audio_queue = queue.Queue()
        self.result_queue: "queue.Queue[str]" = queue.Queue()
        self.break_threads = False
        self.mic_active = False
        self.banned_results = ["", " ", "\n", None]

        self.setup_mic(mic_index)

    def setup_mic(self, mic_index):
        """Configure the recognizer and start background recording.

        Args:
            mic_index: ``device_index`` for ``sr.Microphone``; ``None`` is
                passed through unchanged.
        """
        self.source = sr.Microphone(sample_rate=16000, device_index=mic_index)
        self.recorder = sr.Recognizer()
        self.recorder.energy_threshold = self.energy
        self.recorder.pause_threshold = self.pause
        self.recorder.dynamic_energy_threshold = self.dynamic_energy

        with self.source:
            self.recorder.adjust_for_ambient_noise(self.source)

        # Started after the ``with`` block has released the source, because
        # the background listener opens the microphone itself.
        self.recorder.listen_in_background(
            self.source, self.record_callback, phrase_time_limit=2
        )
        self.logger.info("<認識開始>")

    def preprocess(self, data):
        """Convert raw 16-bit PCM bytes into a float32 tensor.

        Samples are divided by 32768 so they fall in the range [-1.0, 1.0).
        """
        samples = np.frombuffer(data, np.int16).flatten().astype(np.float32)
        return torch.from_numpy(samples / 32768.0)

    def get_all_audio(self, min_time: float = -1.0):
        """Collect queued audio chunks and return them as one bytes object.

        Waits until at least one chunk has arrived and ``min_time`` seconds
        have passed since the call started, then returns everything queued.
        The wait ends early once ``break_threads`` is set, in which case the
        result may be empty.

        Args:
            min_time: Minimum number of seconds to keep collecting.

        Returns:
            The concatenated raw audio bytes, in arrival order.
        """
        chunks = []
        deadline = time.time() + min_time

        while not self.break_threads and (not chunks or time.time() < deadline):
            try:
                # Block briefly instead of spinning on ``empty()`` so waiting
                # for speech does not burn a CPU core.
                chunks.append(self.audio_queue.get(timeout=self._POLL_INTERVAL))
            except queue.Empty:
                continue

        # Pick up whatever else is already waiting without blocking.
        while True:
            try:
                chunks.append(self.audio_queue.get_nowait())
            except queue.Empty:
                break

        return b"".join(chunks)

    def record_callback(self, _, audio: "sr.AudioData") -> None:
        """Background-listener callback: queue the phrase's raw audio bytes."""
        self.audio_queue.put_nowait(audio.get_raw_data())

    def transcribe_forever(self) -> None:
        """Call ``transcribe`` repeatedly until ``break_threads`` is set."""
        while not self.break_threads:
            self.transcribe()

    def transcribe(self, data=None, realtime: bool = False) -> None:
        """Transcribe audio and queue the result.

        Args:
            data: Raw 16-bit PCM bytes.  When ``None`` the audio is taken
                from ``get_all_audio``.
            realtime: Unused; kept so existing callers keep working.

        The recognised text (or, when ``verbose`` is set, the whole result
        dict) is put on ``result_queue`` unless the text is one of
        ``banned_results``.
        """
        audio_data = self.get_all_audio() if data is None else data
        if len(audio_data) == 0:
            # Nothing was recorded (e.g. a stop was requested while waiting),
            # so there is nothing to recognise.
            return

        samples = self.preprocess(audio_data)
        if self.english:
            result = self.audio_model.transcribe(samples, language="english")
        else:
            result = self.audio_model.transcribe(samples, language="ja", fp16=False)

        predicted_text = result["text"]
        if predicted_text not in self.banned_results:
            self.result_queue.put_nowait(result if self.verbose else predicted_text)
        # NOTE: the audio is never written to disk, so there is no file to
        # delete here even when ``save_file`` is set.

    def listen_loop(self, dictate: bool = False) -> None:
        """Print (or type) results until one contains the stop word "終了".

        A worker thread running ``transcribe_forever`` is started; it is
        always asked to stop and joined before this method returns, including
        when the loop is interrupted by an exception.

        Args:
            dictate: Type the recognised text through the keyboard controller
                instead of printing the result.
        """
        self.break_threads = False
        worker = threading.Thread(target=self.transcribe_forever)
        worker.start()
        try:
            while True:
                result = self.result_queue.get()
                # In verbose mode the queue carries the full result dict;
                # the spoken text lives under its "text" key.
                text = result["text"] if isinstance(result, dict) else result
                if dictate:
                    self.keyboard.type(text)
                else:
                    print(result)
                if "終了" in text:
                    break
        finally:
            self.break_threads = True
            self.logger.info("<認識終了>")
            worker.join()  # Wait for the thread to finish

    def listen(self, timeout: int = 3):
        """Record for at least ``timeout`` seconds and return the result.

        Recording and transcription are repeated until something other than
        a banned (empty) result is recognised.

        Args:
            timeout: Minimum number of seconds of collection per attempt.

        Returns:
            The next item from ``result_queue``, or ``None`` if nothing was
            recognised and ``break_threads`` has been set.
        """
        while True:
            self.transcribe(data=self.get_all_audio(timeout))
            try:
                return self.result_queue.get_nowait()
            except queue.Empty:
                if self.break_threads:
                    return None

    def toggle_microphone(self) -> None:
        """Flip ``mic_active`` and report the new state on stdout."""
        # TODO: make this work
        self.mic_active = not self.mic_active
        if self.mic_active:
            print("Mic on")
            return

        print("turning off mic")
        # ``mic_thread`` is not created anywhere in this class, so only join
        # it when something else has attached one.
        mic_thread = getattr(self, "mic_thread", None)
        if mic_thread is not None:
            mic_thread.join()
        print("Mic off")
if __name__ == "__main__":
    # Script entry point: load the "medium" model from the current directory
    # and keep printing transcriptions until the loop stops itself.
    mic = WhisperMic(model="medium", model_root="./")
    mic.listen_loop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment