Local Setup for OpenAI Jukebox

This guide is for those wanting to run OpenAI Jukebox on their own machines.

Do note that you will need a GPU with at least 16GB of VRAM (preferably more) in order to use Jukebox to its fullest potential.

Additionally, you will want to be using a Linux distro of some kind. WSL works too in my experience, but I've found native Linux to be the more stable and reliable option. Native Windows should in theory work (I've never gotten it to), but Windows overhead can leave you with less available VRAM at your disposal, and Jukebox is particularly heavy on that front.

The instructions in this guide and the commands in the notebook assume an Ubuntu/Debian-based system (e.g. Linux Mint), but the things you'll need to install should be available on most distros.

Installation and Setup

  1. If you don't have it already, download and install Miniconda (or Anaconda) from https://docs.anaconda.com/free/miniconda/. I recommend the quick command-line install option at the bottom of that page, as it's the easiest way to get started.
  2. Once your Conda setup is running (you will likely need to restart your shell to apply changes), type `conda create -n jukebox python=3.7` to create the environment we'll be using. The choice of Python 3.7 is important, as it's the version Jukebox is most compatible with.
  3. After that's finished, type `conda activate jukebox` to enter the new environment.
  4. Type `conda install jupyterlab` to install the software that will open our Jukebox notebook.
  5. Enter this command to install PyTorch: `conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia`. Note that PyTorch 2 is not supported by Jukebox, so we're instead installing the last version of PyTorch 1. (See the sanity check after this list to confirm the install can see your GPU.)
  6. Download the notebook below to your computer.
  7. In your command line, type `jupyter-lab`. This will start Jupyter Lab in your default browser. Find the notebook you just downloaded in Jupyter Lab and open it. (Also note that wherever you start `jupyter-lab` will be the place where audio outputs are saved.)
  8. From here, you can follow the directions in the notebook by pressing Shift+Enter to run individual cells. If you run into permission errors when installing system packages within the notebook, you'll want to run those in a separate command-line window with elevated permissions.
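
Optionally, once PyTorch is installed (step 5), you can confirm that it sees your GPU and how much VRAM you have to work with. Here's a quick illustrative check you can run inside the activated `jukebox` environment:

```python
import torch

# Confirm the CUDA build of PyTorch can see a GPU at all.
print("CUDA available:", torch.cuda.is_available())

# Report the name and total VRAM of the first GPU.
# Jukebox wants 16GB or more to run the larger models comfortably.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 1024**3:.1f} GB VRAM")
```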
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uq8uLwZCn0BV"
},
"source": [
"This notebook is modified from SMarioMan's notebook: [https://colab.research.google.com/github/SMarioMan/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb](https://colab.research.google.com/github/SMarioMan/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb)\n",
"\n",
"Welcome to your local Jukebox!\n",
"Jupyter works a little differently from Colab - instead of a play button, you'll need to press Shift+Enter on a selected cell.\n",
"\n",
"Try it on the one below!\n",
"The output should display your GPU(s)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qEqdj8u0gdN"
},
"outputs": [],
"source": [
"!nvidia-smi -L"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zy4Rehq9ZKv_"
},
"source": [
"Next, we'll install some packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!apt update\n",
"!apt install apt-utils -y\n",
"!apt upgrade -y\n",
"!apt install git wget -y\n",
"!apt install libopenmpi-dev libsndfile1-dev -y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll now download the Jukebox code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sAdFGF-bqVMY"
},
"outputs": [],
"source": [
"!pip install git+https://github.com/openai/jukebox.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "taDHgk1WCC_C"
},
"outputs": [],
"source": [
"import jukebox\n",
"import torch as t\n",
"import librosa\n",
"import os\n",
"from IPython.display import Audio\n",
"from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model\n",
"from jukebox.hparams import Hyperparams, setup_hparams\n",
"from jukebox.sample import sample_single_window, _sample, \\\n",
" sample_partial_window, upsample, \\\n",
" load_prompts\n",
"from jukebox.utils.dist_utils import setup_dist_from_mpi\n",
"from jukebox.utils.torch_utils import empty_cache\n",
"rank, local_rank, device = setup_dist_from_mpi()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From here, proceed as you would normally. If you want to use co-composing, scroll all the way down to the start of that section NOW - do not run any cells in the section directly below!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "89FftI5kc-Az"
},
"source": [
"# Sample from the 5B or 1B Lyrics Model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "65aR2OZxmfzq"
},
"outputs": [],
"source": [
"model = '5b_lyrics' # or '5b' or '1b_lyrics'\n",
"hps = Hyperparams()\n",
"hps.sr = 44100\n",
"hps.n_samples = 3 if model in ('5b', '5b_lyrics') else 8\n",
"# Specifies the directory to save the sample in.\n",
"hps.name = 'samples'\n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"hps.levels = 3\n",
"hps.hop_fraction = [.5,.5,.125]\n",
"\n",
"vqvae, *priors = MODELS[model]\n",
"vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)\n",
"top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rvf-5pnjbmI1"
},
"source": [
"# Select mode\n",
"Run one of these cells to select the desired mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VVOQ3egdj65y"
},
"outputs": [],
"source": [
"# The default mode of operation.\n",
"# Creates songs based on artist and genre conditioning.\n",
"mode = 'ancestral'\n",
"codes_file=None\n",
"audio_file=None\n",
"prompt_length_in_seconds=None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vqqv2rJKkMXd"
},
"outputs": [],
"source": [
"# Prime song creation using an arbitrary audio sample.\n",
"mode = 'primed'\n",
"codes_file=None\n",
"# Specify an audio file here.\n",
"audio_file = 'primer.wav'\n",
"# Specify how many seconds of audio to prime on.\n",
"prompt_length_in_seconds=12"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxZMi-S3cT2b"
},
"source": [
"Run this cell to automatically resume from the latest checkpoint file, but only if the checkpoint file exists.\n",
"This will override the selected mode.\n",
"We will assume the existence of a checkpoint means generation is complete and it's time for upsamping to occur."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GjRwyTDhbvf-"
},
"outputs": [],
"source": [
"if os.path.exists(hps.name):\n",
" # Identify the lowest level generated and continue from there.\n",
" for level in [1, 2]:\n",
" data = f\"{hps.name}/level_{level}/data.pth.tar\"\n",
" if os.path.isfile(data):\n",
" mode = 'upsample'\n",
" codes_file = data\n",
" print('Upsampling from level '+str(level))\n",
" break\n",
"print('mode is now '+mode)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UA2UhOZ4YfZj"
},
"source": [
"Run the cell below regardless of which mode you chose."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jp7nKnCmk1bx"
},
"outputs": [],
"source": [
"sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JYKiwkzy0Iyf"
},
"source": [
"Specify your choice of artist, genre, lyrics, and length of musical sample. \n",
"\n",
"IMPORTANT: The sample length is crucial for how long your sample takes to generate. Generating a shorter sample takes less time. A 50 second sample should be short enough to fully generate after 12 hours of processing. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-sY9aGHcZP-u"
},
"outputs": [],
"source": [
"sample_length_in_seconds = 50 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n",
" # range work well, with generation time proportional to sample length. \n",
" # This total length affects how quickly the model \n",
" # progresses through lyrics (model also generates differently\n",
" # depending on if it thinks it's in the beginning, middle, or end of sample)\n",
"hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
"assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, f'Please choose a larger sampling rate'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qD0qxQeLaTR0"
},
"outputs": [],
"source": [
"# Note: Metas can contain different prompts per sample.\n",
"# By default, all samples use the same prompt.\n",
"metas = [dict(artist = \"Rick Astley\",\n",
" genre = \"Pop\",\n",
" total_length = hps.sample_length,\n",
" offset = 0,\n",
" lyrics = \"\"\"We're no strangers to love\n",
"You know the rules and so do I\n",
"A full commitment's what I'm thinking of\n",
"You wouldn't get this from any other guy\n",
"\n",
"I just wanna tell you how I'm feeling\n",
"Gotta make you understand\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"We've known each other for so long\n",
"Your heart's been aching, but\n",
"You're too shy to say it\n",
"Inside, we both know what's been going on\n",
"We know the game and we're gonna play it\n",
"\n",
"And if you ask me how I'm feeling\n",
"Don't tell me you're too blind to see\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"(Ooh, give you up)\n",
"(Ooh, give you up)\n",
"Never gonna give, never gonna give\n",
"(Give you up)\n",
"Never gonna give, never gonna give\n",
"(Give you up)\n",
"\n",
"We've known each other for so long\n",
"Your heart's been aching, but\n",
"You're too shy to say it\n",
"Inside, we both know what's been going on\n",
"We know the game and we're gonna play it\n",
"\n",
"I just wanna tell you how I'm feeling\n",
"Gotta make you understand\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\"\"\",\n",
" ),\n",
" ] * hps.n_samples\n",
"labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6PHC1XnEfV4Y"
},
"source": [
"Optionally adjust the sampling temperature (we've found .98 or .99 to be our favorite). \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eNwKyqYraTR9"
},
"outputs": [],
"source": [
"sampling_temperature = .98\n",
"\n",
"lower_batch_size = 16\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"lower_level_chunk_size = 32\n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,\n",
" chunk_size=lower_level_chunk_size),\n",
" dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,\n",
" chunk_size=lower_level_chunk_size),\n",
" dict(temp=sampling_temperature, fp16=True, \n",
" max_batch_size=max_batch_size, chunk_size=chunk_size)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S3j0gT3HfrRD"
},
"source": [
"Now we're ready to sample from the model. We'll generate the top level (2) first, followed by the first upsampling (level 1), and the second upsampling (0).\n",
"\n",
"After each level, we decode to raw audio and save the audio files. \n",
"\n",
"This next cell will take a while (approximately 10 minutes per 20 seconds of music sample)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9a1tlvcVlHhN"
},
"outputs": [],
"source": [
"if sample_hps.mode == 'ancestral':\n",
" zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]\n",
" zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)\n",
"elif sample_hps.mode == 'upsample':\n",
" assert sample_hps.codes_file is not None\n",
" # Load codes.\n",
" data = t.load(sample_hps.codes_file, map_location='cpu')\n",
" zs = [z.cuda() for z in data['zs']]\n",
" assert zs[-1].shape[0] == hps.n_samples, f\"Expected bs = {hps.n_samples}, got {zs[-1].shape[0]}\"\n",
" del data\n",
" print('Falling through to the upsample step later in the notebook.')\n",
"elif sample_hps.mode == 'primed':\n",
" assert sample_hps.audio_file is not None\n",
" audio_files = sample_hps.audio_file.split(',')\n",
" duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
" x = load_prompts(audio_files, duration, hps)\n",
" zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])\n",
" zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)\n",
"else:\n",
" raise ValueError(f'Unknown sample mode {sample_hps.mode}.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-gxY9aqHqfLJ"
},
"source": [
"Listen to the results from the top level (note this will sound very noisy until we do the upsampling stage). You may have more generated samples, depending on the batch size you requested."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TPZENDGZqOOb"
},
"outputs": [],
"source": [
"Audio(f'{hps.name}/level_2/item_0.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EJc3bQxmusc6"
},
"source": [
"We are now done with the large top_prior model, and instead load the upsamplers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W5VLX0zRapIm"
},
"outputs": [],
"source": [
"# Set this False if you are on a local machine that has enough memory (this allows you to do the\n",
"# lyrics alignment visualization during the upsampling stage). For a hosted runtime, \n",
"# we'll need to go ahead and delete the top_prior if you are using the 5b_lyrics model.\n",
"if True:\n",
" del top_prior\n",
" empty_cache()\n",
" top_prior=None\n",
"upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]\n",
"labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eH_jUhGDprAt"
},
"source": [
"Please note: this next upsampling step will take several hours. As the upsampling is completed, samples will appear in your workspace (You can access this by going back to the filesystem tab or by pressing \"Connect\" on your instance) Level 1 is the partially upsampled version, and then Level 0 is fully completed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9lkJgLolpZ6w"
},
"outputs": [],
"source": [
"zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3SJgBYJPri55"
},
"source": [
"Listen to your final sample!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2ip2PPE0rgAb"
},
"outputs": [],
"source": [
"Audio(f'{hps.name}/level_0/item_0.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8JAgFxytwrLG"
},
"outputs": [],
"source": [
"del upsamplers\n",
"empty_cache()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LpvvFH85bbBC"
},
"source": [
"# Co-Composing with the 5B or 1B Lyrics Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nFDROuS7gFQY"
},
"source": [
"For more control over the generations, try co-composing with either the 5B or 1B Lyrics Models. Again, specify your artist, genre, and lyrics. However, now instead of generating the entire sample, the model will return 3 short options for the opening of the piece (or up to 16 options if you use the 1B model instead). Choose your favorite, and then continue the loop, for as long as you like. Throughout these steps, you'll be listening to the audio at the top prior level, which means it will sound quite noisy. When you are satisfied with your co-creation, continue on through the upsampling section. This will render the piece in higher audio quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3y-q8ifhGBlU"
},
"outputs": [],
"source": [
"model = \"5b_lyrics\" # or \"1b_lyrics\"\n",
"hps = Hyperparams()\n",
"hps.sr = 44100\n",
"hps.n_samples = 3 if model in ('5b', '5b_lyrics') else 16\n",
"# Specifies the directory to save the sample in..\n",
"hps.name = 'co_composer/'\n",
"hps.sample_length = 1048576 if model in ('5b', '5b_lyrics') else 786432 \n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"hps.hop_fraction = [.5, .5, .125] \n",
"hps.levels = 3\n",
"\n",
"vqvae, *priors = MODELS[model]\n",
"vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = hps.sample_length)), device)\n",
"top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0X2oSu0ZASQU"
},
"source": [
"# Select mode\n",
"Run one of these cells to select the desired mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Q6PYFDfwASQU"
},
"outputs": [],
"source": [
"# The default mode of operation.\n",
"# Creates songs based on artist and genre conditioning.\n",
"mode = 'ancestral'\n",
"codes_file=None\n",
"audio_file=None\n",
"prompt_length_in_seconds=None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDiJsmjxASQU"
},
"outputs": [],
"source": [
"# Prime song creation using an arbitrary audio sample.\n",
"mode = 'primed'\n",
"codes_file=None\n",
"# Specify an audio file here.\n",
"audio_file = 'primer.wav'\n",
"# Specify how many seconds of audio to prime on.\n",
"prompt_length_in_seconds=12"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yRNA5smyASQU"
},
"source": [
"Run the cell below regardless of which mode you chose."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHftoOb8ASQU"
},
"outputs": [],
"source": [
"sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "68hz4x7igq0c"
},
"source": [
"Specify your choice of artist, genre, lyrics, and length of musical sample. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1tyHt44PASQU"
},
"outputs": [],
"source": [
"sample_length_in_seconds = 71 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n",
" # range work well, with generation time proportional to sample length. \n",
" # This total length affects how quickly the model \n",
" # progresses through lyrics (model also generates differently\n",
" # depending on if it thinks it's in the beginning, middle, or end of sample)\n",
"hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
"assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, f'Please choose a larger sampling rate'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QDMvH_1zUHo6"
},
"outputs": [],
"source": [
"metas = [dict(artist = \"Zac Brown Band\",\n",
" genre = \"Country\",\n",
" total_length = hps.sample_length,\n",
" offset = 0,\n",
" lyrics = \"\"\"I met a traveller from an antique land,\n",
" Who said—“Two vast and trunkless legs of stone\n",
" Stand in the desert. . . . Near them, on the sand,\n",
" Half sunk a shattered visage lies, whose frown,\n",
" And wrinkled lip, and sneer of cold command,\n",
" Tell that its sculptor well those passions read\n",
" Which yet survive, stamped on these lifeless things,\n",
" The hand that mocked them, and the heart that fed;\n",
" And on the pedestal, these words appear:\n",
" My name is Ozymandias, King of Kings;\n",
" Look on my Works, ye Mighty, and despair!\n",
" Nothing beside remains. Round the decay\n",
" Of that colossal Wreck, boundless and bare\n",
" The lone and level sands stretch far away\n",
" \"\"\",\n",
" ),\n",
" ] * hps.n_samples\n",
"labels = top_prior.labeller.get_batch_labels(metas, 'cuda')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B9onZMEXh34f"
},
"source": [
"## Generate 3 options for the start of the song\n",
"\n",
"Initial generation is set to be 4 seconds long, but feel free to change this"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c6peEj8I_HHO"
},
"outputs": [],
"source": [
"def seconds_to_tokens(sec, sr, prior, chunk_size):\n",
" tokens = sec * hps.sr // prior.raw_to_tokens\n",
" tokens = ((tokens // chunk_size) + 1) * chunk_size\n",
" assert tokens <= prior.n_ctx, 'Choose a shorter generation length to stay within the top prior context'\n",
" return tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2gn2GXt3zt3y"
},
"outputs": [],
"source": [
"initial_generation_in_seconds = 4\n",
"tokens_to_sample = seconds_to_tokens(initial_generation_in_seconds, hps.sr, top_prior, chunk_size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U0zcWcMoiigl"
},
"source": [
"Change the sampling temperature if you like (higher is more random). Our favorite is in the range .98 to .995"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NHbH68H7VMeO"
},
"outputs": [],
"source": [
"sampling_temperature = .98\n",
"\n",
"lower_batch_size = 16\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"lower_level_chunk_size = 32\n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"sampling_kwargs = dict(temp=sampling_temperature, fp16=True, max_batch_size=lower_batch_size,\n",
" chunk_size=lower_level_chunk_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JGZEPe-WTt4g"
},
"outputs": [],
"source": [
"if sample_hps.mode == 'ancestral':\n",
" zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]\n",
" zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
"elif sample_hps.mode == 'primed':\n",
" assert sample_hps.audio_file is not None\n",
" audio_files = sample_hps.audio_file.split(',')\n",
" duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
" x = load_prompts(audio_files, duration, hps)\n",
" zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])\n",
" zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
"x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mveN4Be8jK2J"
},
"source": [
"Listen to your generated samples, and then pick a favorite. If you don't like any, go back and rerun the cell above. \n",
"\n",
"** NOTE this is at the noisy top level, upsample fully (in the next section) to hear the final audio version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LrJSGMhUOhZg"
},
"outputs": [],
"source": [
"for i in range(hps.n_samples):\n",
" librosa.output.write_wav(f'noisy_top_level_generation_{i}.wav', x[i], sr=44100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rQ4ersQ5OhZr"
},
"outputs": [],
"source": [
"Audio('noisy_top_level_generation_0.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-GdqzrGkOhZv"
},
"outputs": [],
"source": [
"Audio('noisy_top_level_generation_1.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gE5S8hyZOhZy"
},
"outputs": [],
"source": [
"Audio('noisy_top_level_generation_2.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t2-mEJaqZfuS"
},
"source": [
"If you don't like any of the options, return a few cells back to \"Sample a few options...\" and rerun from there."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o7CzSiv0MmFP"
},
"source": [
"## Choose your favorite sample and request longer generation\n",
"\n",
"---\n",
"\n",
"(Repeat from here)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j_XFtVi99CIY"
},
"outputs": [],
"source": [
"my_choice=0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pgk3sHHBLYoq"
},
"outputs": [],
"source": [
"zs[2]=zs[2][my_choice].repeat(hps.n_samples,1)\n",
"t.save(zs, 'zs-checkpoint2.t')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W8Rd9xxm565S"
},
"outputs": [],
"source": [
"# Set to True to load the previous checkpoint:\n",
"if False:\n",
" zs=t.load('zs-checkpoint2.t') "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k12xjMgHkRGP"
},
"source": [
"Choose the length of the continuation. The 1B model can generate up to 17 second samples and the 5B up to 23 seconds, but you'll want to pick a shorter continuation length so that it will be able to look back at what you've generated already. Here we've chosen 4 seconds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h3_-0a07kHHG"
},
"outputs": [],
"source": [
"continue_generation_in_seconds=4\n",
"tokens_to_sample = seconds_to_tokens(continue_generation_in_seconds, hps.sr, top_prior, chunk_size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GpPG3Ifqk8ue"
},
"source": [
"The next step asks the top prior to generate more of the sample. It'll take up to a few minutes, depending on the sample length you request."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YoHkeSTaEyLj"
},
"outputs": [],
"source": [
"zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
"x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymhUqEdhleEi"
},
"source": [
"Now listen to the longer versions of the sample you selected, and again choose a favorite sample. If you don't like any, return back to the cell where you can load the checkpoint, and continue again from there.\n",
"\n",
"When the samples start getting long, you might not always want to listen from the start, so change the playback start time later on if you like."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2H1LNLTa_R6a"
},
"outputs": [],
"source": [
"playback_start_time_in_seconds = 0 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r4SBGAmsnJtH"
},
"outputs": [],
"source": [
"for i in range(hps.n_samples):\n",
" librosa.output.write_wav(f'top_level_continuation_{i}.wav', x[i][playback_start_time_in_seconds*44100:], sr=44100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2WeyE5Qtnmeo"
},
"outputs": [],
"source": [
"Audio('top_level_continuation_0.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BKtfEtcaazXE"
},
"outputs": [],
"source": [
"Audio('top_level_continuation_1.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yrlS0XwK2S0"
},
"outputs": [],
"source": [
"Audio('top_level_continuation_2.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-OJT704dvnGv"
},
"source": [
"To make a longer song, return back to \"Choose your favorite sample\" and loop through that again"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzCrkCZJvUcQ"
},
"source": [
"# Upsample Co-Composition to Higher Audio Quality"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4MPgukwMmB0p"
},
"source": [
"Choose your favorite sample from your latest group of generations. (If you haven't already gone through the Co-Composition block, make sure to do that first so you have a generation to upsample)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yv-pNNPHBQYC"
},
"outputs": [],
"source": [
"choice = 0\n",
"select_best_sample = True # Set false if you want to upsample all your samples \n",
" # upsampling sometimes yields subtly different results on multiple runs,\n",
" # so this way you can choose your favorite upsampling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v17cEAqyCgfo"
},
"outputs": [],
"source": [
"if select_best_sample:\n",
" zs[2]=zs[2][choice].repeat(zs[2].shape[0],1)\n",
"\n",
"t.save(zs, 'zs-top-level-final.t')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qqlR9368s3jJ"
},
"outputs": [],
"source": [
"if False:\n",
" zs = t.load('zs-top-level-final.t')\n",
"\n",
"assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'\n",
"hps.sample_length = zs[2].shape[1]*top_prior.raw_to_tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jzHwF_iqgIWM"
},
"outputs": [],
"source": [
"# Set this False if your instance has 24GB or more of VRAM.\n",
"if True:\n",
" del top_prior\n",
" empty_cache()\n",
" top_prior=None\n",
"\n",
"upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q22Ier6YSkKS"
},
"outputs": [],
"source": [
"sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
" dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
" dict(fp16=True)]\n",
"\n",
"if type(labels)==dict:\n",
" labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [labels] "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T1MCa9_jnjpf"
},
"source": [
"This next step upsamples 2 levels. The level_1 samples will be available after around one hour (depending on the length of your sample) and are saved under {hps.name}/level_0/item_0.wav, while the fully upsampled level_0 will likely take 4-12 hours. You can access the wav files down below, or using the file browser."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NcNT5qIRMmHq"
},
"outputs": [],
"source": [
"zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W2jTYLPBc29M"
},
"outputs": [],
"source": [
"Audio(f'{hps.name}/level_0/item_0.wav')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "Interacting with Jukebox",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}