@esnya
Created June 11, 2023 03:08
Something I made while playing with audiocraft. It more or less keeps chaining generations together indefinitely.
import argparse
import asyncio
from contextlib import contextmanager
from functools import cached_property
from typing import Any, Callable, Generator, Optional, Tuple
import pyaudio
import torch
from audiocraft.data.audio import audio_write
from audiocraft.models.musicgen import MusicGen
from torchaudio.backend.soundfile_backend import load
from websockets.server import WebSocketServerProtocol, serve


class MusicGenerator:
    """Streams MusicGen output to an audio device, chaining continuations indefinitely."""

    def __init__(self, model: MusicGen):
        self.model = model
        self.wav_tensor: Optional[torch.Tensor] = None
        self.stream = None
        self.description = ""
        self.volume = 0.5
        self.melody: Optional[Tuple[torch.Tensor, int]] = None
        self.prompt_duration = 1.0
        self.output_device_index: Optional[int] = None
        self._stop = asyncio.Event()

    @cached_property
    def _commands(self) -> dict[str, Callable]:
        # Console commands: /exit, /duration <s>, /volume <0-1>, /melody <path>,
        # /clear_melody, /save <path>.
        return {
            "/exit": lambda _: self._stop.set(),
            "/duration": lambda input_str: self.model.set_generation_params(
                duration=float(input_str.split(" ", 1)[1])
            ),
            "/volume": lambda input_str: setattr(
                self, "volume", float(input_str.split(" ", 1)[1])
            ),
            "/melody": lambda input_str: self.set_melody(input_str.split(" ", 1)[1]),
            "/clear_melody": lambda _: setattr(self, "melody", None),
            "/save": lambda input_str: asyncio.to_thread(
                audio_write,
                input_str.split(" ", 1)[1],
                self.wav_tensor.cpu(),
                self.model.sample_rate,
                strategy="loudness",
            )
            if self.wav_tensor is not None
            else None,
        }

    def set_melody(self, filepath: str):
        self.melody = load(filepath)

    @contextmanager
    def _open_stream(self) -> Generator[pyaudio.Stream, Any, None]:
        stream = pyaudio.PyAudio().open(
            format=pyaudio.paFloat32,
            channels=1,
            rate=self.model.sample_rate,
            output=True,
            output_device_index=self.output_device_index,
        )
        self.stream = stream
        try:
            yield stream
        finally:
            self.stream.close()

    async def _generate_loop(self):
        # Read console input and either run a command or treat it as a new description.
        while not self._stop.is_set():
            input_str: str = await asyncio.to_thread(
                input, f"{self.description or 'Description'}> "
            )
            command = input_str.lower().split(" ", 1)[0]
            if command in self._commands:
                try:
                    result = self._commands[command](input_str)
                    if asyncio.iscoroutine(result):
                        # Commands such as /save return a coroutine; await it so it runs.
                        await result
                except Exception as e:
                    print(e)
            else:
                self.description = input_str.strip() or self.description
                # await self._generate()

    async def _generate(self):
        if self.melody:
            try:
                melody, melody_sr = self.melody
                output = await asyncio.to_thread(
                    self.model.generate_with_chroma,
                    [self.description],
                    melody.expand(1, -1, -1),
                    melody_sr,
                    progress=True,
                )
            except SystemError as e:
                print(e)
                self.melody = None
                return await self._generate()
        else:
            output = await asyncio.to_thread(
                self.model.generate, [self.description], progress=True
            )
        self.wav_tensor = output.cpu()

    async def _play_loop(self, stream: pyaudio.Stream):
        while stream.is_active() and not self._stop.is_set():
            if self.wav_tensor is None:
                await asyncio.sleep(0)
                continue
            # Scale by the configured volume before writing to the output stream.
            await asyncio.to_thread(
                stream.write, (self.wav_tensor * self.volume).numpy().tobytes()
            )
        self._stop.set()

    async def _continuation_loop(self):
        # Keep extending the audio: the tail of the last clip becomes the prompt
        # for the next generation.
        while not self._stop.is_set():
            if self.wav_tensor is None:
                await asyncio.sleep(0)
                continue
            prompt_frames = int(self.prompt_duration * self.model.sample_rate)
            # self.model.generation_params["remove_prompts"] = True
            output = await asyncio.to_thread(
                self.model.generate_continuation,
                self.wav_tensor[:, :, -prompt_frames:],
                self.model.sample_rate,
                [self.description],
            )
            self.wav_tensor = output.cpu()

    async def start(self):
        with self._open_stream() as stream:
            if self.description:
                await self._generate()
            await asyncio.gather(
                asyncio.create_task(self._generate_loop()),
                asyncio.create_task(self._play_loop(stream)),
                asyncio.create_task(self._continuation_loop()),
            )


def set_pad_mode_recursive(target, pad_mode: str, _done_list=set(), name="$"):
    # Recursively set `pad_mode` on every submodule that exposes it.
    if target in _done_list:
        return
    _done_list.add(target)
    if hasattr(target, "pad_mode"):
        print(f"{name}({target}).pad_mode = {pad_mode}")
        target.pad_mode = pad_mode
    for key in dir(target):
        child = getattr(target, key)
        # print(key, child.__class__.__name__)
        if isinstance(child, torch.nn.Module):
            set_pad_mode_recursive(child, pad_mode, _done_list, f"{name}.{key}")


async def main():
    parser = argparse.ArgumentParser(description="Jukebox powered by AudioCraft")
    parser.add_argument(
        "--model", default="melody", type=str, help="Name of the pretrained model."
    )
    parser.add_argument(
        "--device", default="cuda", type=str, help="Device to use for generation."
    )
    parser.add_argument(
        "--description",
        default="Simple Music",
        type=str,
        help="Initial description for the music generator.",
    )
    parser.add_argument(
        "--duration", default=15, type=float, help="Duration for music generation."
    )
    parser.add_argument(
        "--volume", default=0.5, type=float, help="Volume to play the generated music."
    )
    parser.add_argument(
        "--melody", default=None, type=str, help="Path to the melody to use."
    )
    # NOTE: parsed but not currently wired into the generator.
    parser.add_argument(
        "--continuous-overlap",
        default=15,
        type=float,
        help="Overlap duration for continuous generation.",
    )
    parser.add_argument(
        "--list-audio-devices",
        action="store_true",
        help="List all available audio devices.",
    )
    parser.add_argument(
        "--output-device-index",
        default=None,
        type=int,
        help="Index of the output device to use.",
    )
    args = parser.parse_args()

    if args.list_audio_devices:
        pa = pyaudio.PyAudio()
        for i in range(pa.get_device_count()):
            info = pa.get_device_info_by_index(i)
            if info["maxOutputChannels"] == 0:
                continue
            print(info)
        return

    model = MusicGen.get_pretrained(args.model, args.device)
    model.set_generation_params(duration=args.duration)
    set_pad_mode_recursive(model, "circular")

    generator = MusicGenerator(model)
    generator.description = args.description
    generator.volume = args.volume
    generator.output_device_index = args.output_device_index
    if args.melody:
        generator.set_melody(args.melody)

    async def handle_websocket(websocket: WebSocketServerProtocol):
        # Each connection delivers a single description that replaces the current one.
        description = await websocket.recv()
        print(f"\n{description}")
        generator.description = description

    async with serve(handle_websocket, "localhost", 8001):
        await generator.start()


if __name__ == "__main__":
    asyncio.run(main())
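
For reference, here is a minimal client sketch for steering the jukebox while it runs. It assumes the script above is already serving on ws://localhost:8001 and uses the standard websockets client API; since the handler reads one description per connection, each new prompt opens a fresh connection.

# Minimal client sketch (assumes the jukebox above is running on localhost:8001).
import asyncio
from websockets.client import connect


async def send_description(text: str) -> None:
    # The server reads a single description per connection, so connect,
    # send one message, and let the connection close.
    async with connect("ws://localhost:8001") as ws:
        await ws.send(text)


if __name__ == "__main__":
    asyncio.run(send_description("calm ambient piano, slow tempo"))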