
@WAUthethird
Last active February 13, 2024 07:10
Using vast.ai with OpenAI Jukebox

vast.ai is an easy-to-use and comparatively cheap GPU rental marketplace: hosts rent out their GPUs' compute to other users, often for a much lower hourly price than most other cloud providers. In this guide, I will walk you through using vast.ai to quickly generate content with OpenAI Jukebox, since the GPUs available can be much faster than the ones provided by Google Colab while still being reasonably priced.

The .ipynb notebook is based off of this one on Colab: https://colab.research.google.com/github/SMarioMan/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb

  1. Go to vast.ai and create an account. You will also need to provide your payment info.
    • You may or may not receive the "few minutes of trial credit" advertised in the top right. If you don't, you can ask about it through their live chat.
  2. Next, set up the instance configuration. Make sure you are on the "Create" page and look to the left of the list of GPUs. Set the disk space slider to around 32GB, then press "Edit Image & Config..."
  3. You will see a variety of options. Scroll down and press the "Select" button next to pytorch/pytorch. Make sure the jupyter-python-notebook setting is enabled, then expand the dropdown menu above it, which lists quite a few Docker version tags. Select (THIS IS VERY IMPORTANT) 1.10.0-cuda-11.3-cudnn8-runtime. Press "Select" at the bottom.
  4. Before adding credit, consider your needs. How much will you use the service, and for how long? Are you planning to upsample your songs or create long songs? Based on your answers, add roughly the amount of credit you expect your use case to need.
    • Whether or not you are signed in, the "Create" tab shows a list of GPU-accelerated machines. A greyed-out machine is currently rented by someone else and may become available later. To see the full list of available and unavailable machines, check "Include Unverified Machines" and "Include Incompatible Machines" in the "Filter offers" section. You can also sort by price.
    • Here is my list of recommended GPU configurations that should work with OpenAI Jukebox:
      • 1x V100 (32GB ver)
      • 1x Titan RTX
      • 1x RTX 3090
      • 1x RTX 3090 Ti
      • 1x RTX 4090
      • 1x Quadro RTX 6000
      • 1x Quadro RTX 8000
      • 1x Quadro RTX A4500
      • 1x Quadro RTX A5000 (desktop)
      • 1x Quadro RTX A5500
      • 1x Quadro RTX A6000
      • 1x A10
      • 1x A30
      • 1x A40
      • 1x A100
      • 1x H100
    • I do not recommend using 16GB of VRAM with OpenAI Jukebox. The GPUs above have 24GB or more of VRAM, except the RTX A4500, which has 20GB. (I haven't included multi-GPU configs because, to my knowledge, Jukebox does not generate across multiple GPUs and would only use one of them.)
    • Here is my list of minimum-spec (16GB) GPU configurations that should work with OpenAI Jukebox:
      • 1x RTX 4090 (mobile)
      • 1x RTX 4080
      • 1x RTX 3080 Ti (mobile)
      • 1x Quadro RTX 5000
      • 1x Tesla T4
      • 1x Titan V (note: the standard Titan V has 12GB of VRAM; only the 32GB CEO Edition meets this minimum)
      • 1x Tesla V100 (16GB ver)
      • 1x Quadro RTX A4000 (desktop)
      • 1x A2
  5. Now that everything is set up, find an instance whose price and network speed seem reasonable to you and press "Rent". A good download speed matters, since the 10GB model download will likely take a while.
  6. Go to the "Instances" tab and you'll see your machine being set up! You may need to wait a few minutes for it to set everything up on its end, but eventually you'll be able to connect to it.
  7. A new tab should open in your browser (if it shows a network error, try reloading; the instance may still be doing some last-minute setup) and you'll see a file browser. Ignore what's in there and upload the .ipynb file linked below. (Or click this link to download it, right-click the page and choose "Save As". Your browser may append .txt to the filename; delete it before or after saving.)
    • Note: you'll need to download this file to your local machine, then press the "Upload" button on the file browser page. Next, press "Upload" next to the file itself and wait for it to finish. Once done, click the name of the newly uploaded file to open the notebook in a new tab. It should feel quite similar to Google Colab. (Further instructions are in the notebook.)
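To put rough numbers on step 4, here is a small Python sketch of the arithmetic. It is only an estimate under assumptions: the 10 minutes per 20 seconds of top-level audio comes from the notebook's Colab-era note (faster vast.ai GPUs will likely beat it), the 12 upsampling hours is the top of the notebook's 4-12 hour range, the 0.5-hour setup allowance is an assumed figure, and the hourly price is a placeholder you should replace with the price shown on the offer you pick. `estimate_hours` and `estimate_credit` are hypothetical helpers, not part of vast.ai or Jukebox.

```python
"""Rough vast.ai credit estimate for one Jukebox run (hypothetical helper)."""


def estimate_hours(
    sample_seconds: float,
    top_level_minutes_per_20s: float = 10.0,  # notebook's Colab-era figure (assumption)
    upsample_hours: float = 12.0,  # top of the notebook's 4-12 hour range (assumption)
    setup_hours: float = 0.5,  # assumed allowance for setup and the model download
) -> float:
    """Return estimated instance hours: setup + top-level sampling + upsampling."""
    top_level_hours = sample_seconds / 20 * top_level_minutes_per_20s / 60
    return setup_hours + top_level_hours + upsample_hours


def estimate_credit(sample_seconds: float, price_per_hour: float, **kwargs: float) -> float:
    """Multiply the estimated hours by the instance's hourly price."""
    return estimate_hours(sample_seconds, **kwargs) * price_per_hour


if __name__ == "__main__":
    # Placeholder price: replace with the $/hr shown on the offer you rent.
    price = 0.50
    for seconds in (50, 60, 120):
        hours = estimate_hours(seconds)
        print(f"{seconds:>4} s of audio: ~{hours:.1f} h, ~${estimate_credit(seconds, price):.2f}")
```

For example, `estimate_credit(60, 0.50)` works out to 13 hours, or $6.50. Pass `upsample_hours=0` if you only want the noisy top-level output.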

Additional Notes

  • Don't close the notebook page or its connection, even though the notebook would keep running. I haven't found a reliable way to reconnect to an existing notebook kernel, so you will probably need to keep it open for as long as you're using it, just like Colab.
  • If you want to run the notebook locally, you will need an Anaconda environment with Python 3.7 and the last PyTorch 1.x release (1.13.1).
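If you do try running it locally, a quick sanity check like the Python sketch below can save you a failed run. It assumes the thresholds from my GPU lists above (24GB recommended, 16GB minimum) and adds a 10% slack, which is an assumption, because PyTorch reports slightly less memory than a card's marketed size. It only inspects the first GPU, since, to my knowledge, Jukebox only uses one. `vram_tier` and `check_environment` are hypothetical helpers.

```python
"""Sanity-check a local environment against this guide's suggestions (hypothetical helper)."""
import sys

RECOMMENDED_GB = 24  # recommended tier from the guide
MINIMUM_GB = 16  # minimum tier from the guide
SLACK = 0.9  # assumed tolerance: reported memory is a bit below the marketed size


def vram_tier(total_bytes: int) -> str:
    """Classify a GPU's reported memory into the guide's tiers."""
    gib = total_bytes / 2**30
    if gib >= RECOMMENDED_GB * SLACK:
        return "recommended"
    if gib >= MINIMUM_GB * SLACK:
        return "minimum"
    return "not recommended"


def check_environment() -> list:
    """Return human-readable problems; an empty list means nothing was flagged."""
    problems = []
    if sys.version_info[:2] != (3, 7):
        problems.append(f"Python {sys.version_info[0]}.{sys.version_info[1]} found; the guide suggests 3.7")
    try:
        import torch
    except ImportError:
        problems.append("PyTorch is not installed")
        return problems
    if not torch.__version__.startswith("1."):
        problems.append(f"PyTorch {torch.__version__} found; the guide suggests the last 1.x release")
    if not torch.cuda.is_available():
        problems.append("PyTorch cannot see a CUDA GPU")
    else:
        props = torch.cuda.get_device_properties(0)  # first GPU only
        tier = vram_tier(props.total_memory)
        if tier != "recommended":
            problems.append(f"{props.name}: {props.total_memory / 2**30:.1f} GiB of VRAM ({tier})")
    return problems


if __name__ == "__main__":
    for problem in check_environment() or ["No problems found"]:
        print(problem)
```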
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uq8uLwZCn0BV"
},
"source": [
"This notebook is modified from SMarioMan's notebook: [https://colab.research.google.com/github/SMarioMan/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb](https://colab.research.google.com/github/SMarioMan/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb)\n",
"\n",
"Welcome to your GPU instance!\n",
"Jupyter works a little differently from Colab - instead of a play button, you'll need to press Shift+Enter on a selected cell.\n",
"\n",
"Try it on the one below!\n",
"The output should display your GPU(s)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qEqdj8u0gdN"
},
"outputs": [],
"source": [
"!nvidia-smi -L"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zy4Rehq9ZKv_"
},
"source": [
"Next, we'll install some packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!apt update\n",
"!apt install apt-utils -y\n",
"!apt upgrade -y\n",
"!apt install git wget -y\n",
"!apt install libopenmpi-dev libsndfile1-dev -y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll now download the Jukebox code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sAdFGF-bqVMY"
},
"outputs": [],
"source": [
"!pip install git+https://github.com/openai/jukebox.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "taDHgk1WCC_C"
},
"outputs": [],
"source": [
"import jukebox\n",
"import torch as t\n",
"import librosa\n",
"import os\n",
"from IPython.display import Audio\n",
"from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model\n",
"from jukebox.hparams import Hyperparams, setup_hparams\n",
"from jukebox.sample import sample_single_window, _sample, \\\n",
" sample_partial_window, upsample, \\\n",
" load_prompts\n",
"from jukebox.utils.dist_utils import setup_dist_from_mpi\n",
"from jukebox.utils.torch_utils import empty_cache\n",
"rank, local_rank, device = setup_dist_from_mpi()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From here, proceed as you would normally. If you want to use co-composing, scroll all the way down to the start of that section NOW - do not run any cells in the section directly below!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "89FftI5kc-Az"
},
"source": [
"# Sample from the 5B or 1B Lyrics Model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "65aR2OZxmfzq"
},
"outputs": [],
"source": [
"model = '5b_lyrics' # or '5b' or '1b_lyrics'\n",
"hps = Hyperparams()\n",
"hps.sr = 44100\n",
"hps.n_samples = 3 if model in ('5b', '5b_lyrics') else 8\n",
"# Specifies the directory to save the sample in.\n",
"hps.name = '/workspace/samples'\n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"hps.levels = 3\n",
"hps.hop_fraction = [.5,.5,.125]\n",
"\n",
"vqvae, *priors = MODELS[model]\n",
"vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)\n",
"top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rvf-5pnjbmI1"
},
"source": [
"# Select mode\n",
"Run one of these cells to select the desired mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VVOQ3egdj65y"
},
"outputs": [],
"source": [
"# The default mode of operation.\n",
"# Creates songs based on artist and genre conditioning.\n",
"mode = 'ancestral'\n",
"codes_file=None\n",
"audio_file=None\n",
"prompt_length_in_seconds=None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vqqv2rJKkMXd"
},
"outputs": [],
"source": [
"# Prime song creation using an arbitrary audio sample.\n",
"mode = 'primed'\n",
"codes_file=None\n",
"# Specify an audio file here.\n",
"audio_file = '/workspace/primer.wav'\n",
"# Specify how many seconds of audio to prime on.\n",
"prompt_length_in_seconds=12"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxZMi-S3cT2b"
},
"source": [
"Run this cell to automatically resume from the latest checkpoint file, but only if the checkpoint file exists.\n",
"This will override the selected mode.\n",
"We will assume the existence of a checkpoint means generation is complete and it's time for upsamping to occur."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GjRwyTDhbvf-"
},
"outputs": [],
"source": [
"if os.path.exists(hps.name):\n",
" # Identify the lowest level generated and continue from there.\n",
" for level in [1, 2]:\n",
" data = f\"{hps.name}/level_{level}/data.pth.tar\"\n",
" if os.path.isfile(data):\n",
" mode = 'upsample'\n",
" codes_file = data\n",
" print('Upsampling from level '+str(level))\n",
" break\n",
"print('mode is now '+mode)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UA2UhOZ4YfZj"
},
"source": [
"Run the cell below regardless of which mode you chose."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jp7nKnCmk1bx"
},
"outputs": [],
"source": [
"sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JYKiwkzy0Iyf"
},
"source": [
"Specify your choice of artist, genre, lyrics, and length of musical sample. \n",
"\n",
"IMPORTANT: The sample length is crucial for how long your sample takes to generate. Generating a shorter sample takes less time. A 50 second sample should be short enough to fully generate after 12 hours of processing. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-sY9aGHcZP-u"
},
"outputs": [],
"source": [
"sample_length_in_seconds = 50 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n",
" # range work well, with generation time proportional to sample length. \n",
" # This total length affects how quickly the model \n",
" # progresses through lyrics (model also generates differently\n",
" # depending on if it thinks it's in the beginning, middle, or end of sample)\n",
"hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
"assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, f'Please choose a larger sampling rate'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qD0qxQeLaTR0"
},
"outputs": [],
"source": [
"# Note: Metas can contain different prompts per sample.\n",
"# By default, all samples use the same prompt.\n",
"metas = [dict(artist = \"Rick Astley\",\n",
" genre = \"Pop\",\n",
" total_length = hps.sample_length,\n",
" offset = 0,\n",
" lyrics = \"\"\"We're no strangers to love\n",
"You know the rules and so do I\n",
"A full commitment's what I'm thinking of\n",
"You wouldn't get this from any other guy\n",
"\n",
"I just wanna tell you how I'm feeling\n",
"Gotta make you understand\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"We've known each other for so long\n",
"Your heart's been aching, but\n",
"You're too shy to say it\n",
"Inside, we both know what's been going on\n",
"We know the game and we're gonna play it\n",
"\n",
"And if you ask me how I'm feeling\n",
"Don't tell me you're too blind to see\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"(Ooh, give you up)\n",
"(Ooh, give you up)\n",
"Never gonna give, never gonna give\n",
"(Give you up)\n",
"Never gonna give, never gonna give\n",
"(Give you up)\n",
"\n",
"We've known each other for so long\n",
"Your heart's been aching, but\n",
"You're too shy to say it\n",
"Inside, we both know what's been going on\n",
"We know the game and we're gonna play it\n",
"\n",
"I just wanna tell you how I'm feeling\n",
"Gotta make you understand\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\n",
"Never gonna give you up\n",
"Never gonna let you down\n",
"Never gonna run around and desert you\n",
"Never gonna make you cry\n",
"Never gonna say goodbye\n",
"Never gonna tell a lie and hurt you\n",
"\"\"\",\n",
" ),\n",
" ] * hps.n_samples\n",
"labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6PHC1XnEfV4Y"
},
"source": [
"Optionally adjust the sampling temperature (we've found .98 or .99 to be our favorite). \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eNwKyqYraTR9"
},
"outputs": [],
"source": [
"sampling_temperature = .98\n",
"\n",
"lower_batch_size = 16\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"lower_level_chunk_size = 32\n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,\n",
" chunk_size=lower_level_chunk_size),\n",
" dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,\n",
" chunk_size=lower_level_chunk_size),\n",
" dict(temp=sampling_temperature, fp16=True, \n",
" max_batch_size=max_batch_size, chunk_size=chunk_size)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S3j0gT3HfrRD"
},
"source": [
"Now we're ready to sample from the model. We'll generate the top level (2) first, followed by the first upsampling (level 1), and the second upsampling (0).\n",
"\n",
"After each level, we decode to raw audio and save the audio files. \n",
"\n",
"This next cell will take a while (approximately 10 minutes per 20 seconds of music sample)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9a1tlvcVlHhN"
},
"outputs": [],
"source": [
"if sample_hps.mode == 'ancestral':\n",
" zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]\n",
" zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)\n",
"elif sample_hps.mode == 'upsample':\n",
" assert sample_hps.codes_file is not None\n",
" # Load codes.\n",
" data = t.load(sample_hps.codes_file, map_location='cpu')\n",
" zs = [z.cuda() for z in data['zs']]\n",
" assert zs[-1].shape[0] == hps.n_samples, f\"Expected bs = {hps.n_samples}, got {zs[-1].shape[0]}\"\n",
" del data\n",
" print('Falling through to the upsample step later in the notebook.')\n",
"elif sample_hps.mode == 'primed':\n",
" assert sample_hps.audio_file is not None\n",
" audio_files = sample_hps.audio_file.split(',')\n",
" duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
" x = load_prompts(audio_files, duration, hps)\n",
" zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])\n",
" zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)\n",
"else:\n",
" raise ValueError(f'Unknown sample mode {sample_hps.mode}.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-gxY9aqHqfLJ"
},
"source": [
"Listen to the results from the top level (note this will sound very noisy until we do the upsampling stage). You may have more generated samples, depending on the batch size you requested."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TPZENDGZqOOb"
},
"outputs": [],
"source": [
"Audio(f'{hps.name}/level_2/item_0.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EJc3bQxmusc6"
},
"source": [
"We are now done with the large top_prior model, and instead load the upsamplers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W5VLX0zRapIm"
},
"outputs": [],
"source": [
"# Set this False if you are on a local machine that has enough memory (this allows you to do the\n",
"# lyrics alignment visualization during the upsampling stage). For a hosted runtime, \n",
"# we'll need to go ahead and delete the top_prior if you are using the 5b_lyrics model.\n",
"if True:\n",
" del top_prior\n",
" empty_cache()\n",
" top_prior=None\n",
"upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]\n",
"labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eH_jUhGDprAt"
},
"source": [
"Please note: this next upsampling step will take several hours. As the upsampling is completed, samples will appear in the vast workspace (You can access this by going back to the filesystem tab or by pressing \"Connect\" on your instance) Level 1 is the partially upsampled version, and then Level 0 is fully completed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9lkJgLolpZ6w"
},
"outputs": [],
"source": [
"zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3SJgBYJPri55"
},
"source": [
"Listen to your final sample!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2ip2PPE0rgAb"
},
"outputs": [],
"source": [
"Audio(f'{hps.name}/level_0/item_0.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8JAgFxytwrLG"
},
"outputs": [],
"source": [
"del upsamplers\n",
"empty_cache()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LpvvFH85bbBC"
},
"source": [
"# Co-Composing with the 5B or 1B Lyrics Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nFDROuS7gFQY"
},
"source": [
"For more control over the generations, try co-composing with either the 5B or 1B Lyrics Models. Again, specify your artist, genre, and lyrics. However, now instead of generating the entire sample, the model will return 3 short options for the opening of the piece (or up to 16 options if you use the 1B model instead). Choose your favorite, and then continue the loop, for as long as you like. Throughout these steps, you'll be listening to the audio at the top prior level, which means it will sound quite noisy. When you are satisfied with your co-creation, continue on through the upsampling section. This will render the piece in higher audio quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3y-q8ifhGBlU"
},
"outputs": [],
"source": [
"model = \"5b_lyrics\" # or \"1b_lyrics\"\n",
"hps = Hyperparams()\n",
"hps.sr = 44100\n",
"hps.n_samples = 3 if model in ('5b', '5b_lyrics') else 16\n",
"# Specifies the directory to save the sample in..\n",
"hps.name = '/workspace/co_composer'\n",
"hps.sample_length = 1048576 if model in ('5b', '5b_lyrics') else 786432 \n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"hps.hop_fraction = [.5, .5, .125] \n",
"hps.levels = 3\n",
"\n",
"vqvae, *priors = MODELS[model]\n",
"vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = hps.sample_length)), device)\n",
"top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0X2oSu0ZASQU"
},
"source": [
"# Select mode\n",
"Run one of these cells to select the desired mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Q6PYFDfwASQU"
},
"outputs": [],
"source": [
"# The default mode of operation.\n",
"# Creates songs based on artist and genre conditioning.\n",
"mode = 'ancestral'\n",
"codes_file=None\n",
"audio_file=None\n",
"prompt_length_in_seconds=None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDiJsmjxASQU"
},
"outputs": [],
"source": [
"# Prime song creation using an arbitrary audio sample.\n",
"mode = 'primed'\n",
"codes_file=None\n",
"# Specify an audio file here.\n",
"audio_file = '/workspace/primer.wav'\n",
"# Specify how many seconds of audio to prime on.\n",
"prompt_length_in_seconds=12"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yRNA5smyASQU"
},
"source": [
"Run the cell below regardless of which mode you chose."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHftoOb8ASQU"
},
"outputs": [],
"source": [
"sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "68hz4x7igq0c"
},
"source": [
"Specify your choice of artist, genre, lyrics, and length of musical sample. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1tyHt44PASQU"
},
"outputs": [],
"source": [
"sample_length_in_seconds = 71 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n",
" # range work well, with generation time proportional to sample length. \n",
" # This total length affects how quickly the model \n",
" # progresses through lyrics (model also generates differently\n",
" # depending on if it thinks it's in the beginning, middle, or end of sample)\n",
"hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
"assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, f'Please choose a larger sampling rate'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QDMvH_1zUHo6"
},
"outputs": [],
"source": [
"metas = [dict(artist = \"Zac Brown Band\",\n",
" genre = \"Country\",\n",
" total_length = hps.sample_length,\n",
" offset = 0,\n",
" lyrics = \"\"\"I met a traveller from an antique land,\n",
" Who said—“Two vast and trunkless legs of stone\n",
" Stand in the desert. . . . Near them, on the sand,\n",
" Half sunk a shattered visage lies, whose frown,\n",
" And wrinkled lip, and sneer of cold command,\n",
" Tell that its sculptor well those passions read\n",
" Which yet survive, stamped on these lifeless things,\n",
" The hand that mocked them, and the heart that fed;\n",
" And on the pedestal, these words appear:\n",
" My name is Ozymandias, King of Kings;\n",
" Look on my Works, ye Mighty, and despair!\n",
" Nothing beside remains. Round the decay\n",
" Of that colossal Wreck, boundless and bare\n",
" The lone and level sands stretch far away\n",
" \"\"\",\n",
" ),\n",
" ] * hps.n_samples\n",
"labels = top_prior.labeller.get_batch_labels(metas, 'cuda')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B9onZMEXh34f"
},
"source": [
"## Generate 3 options for the start of the song\n",
"\n",
"Initial generation is set to be 4 seconds long, but feel free to change this"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c6peEj8I_HHO"
},
"outputs": [],
"source": [
"def seconds_to_tokens(sec, sr, prior, chunk_size):\n",
" tokens = sec * hps.sr // prior.raw_to_tokens\n",
" tokens = ((tokens // chunk_size) + 1) * chunk_size\n",
" assert tokens <= prior.n_ctx, 'Choose a shorter generation length to stay within the top prior context'\n",
" return tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2gn2GXt3zt3y"
},
"outputs": [],
"source": [
"initial_generation_in_seconds = 4\n",
"tokens_to_sample = seconds_to_tokens(initial_generation_in_seconds, hps.sr, top_prior, chunk_size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U0zcWcMoiigl"
},
"source": [
"Change the sampling temperature if you like (higher is more random). Our favorite is in the range .98 to .995"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NHbH68H7VMeO"
},
"outputs": [],
"source": [
"sampling_temperature = .98\n",
"\n",
"lower_batch_size = 16\n",
"max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n",
"lower_level_chunk_size = 32\n",
"chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n",
"sampling_kwargs = dict(temp=sampling_temperature, fp16=True, max_batch_size=lower_batch_size,\n",
" chunk_size=lower_level_chunk_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JGZEPe-WTt4g"
},
"outputs": [],
"source": [
"if sample_hps.mode == 'ancestral':\n",
" zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]\n",
" zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
"elif sample_hps.mode == 'primed':\n",
" assert sample_hps.audio_file is not None\n",
" audio_files = sample_hps.audio_file.split(',')\n",
" duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
" x = load_prompts(audio_files, duration, hps)\n",
" zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])\n",
" zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
"x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mveN4Be8jK2J"
},
"source": [
"Listen to your generated samples, and then pick a favorite. If you don't like any, go back and rerun the cell above. \n",
"\n",
"** NOTE this is at the noisy top level, upsample fully (in the next section) to hear the final audio version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LrJSGMhUOhZg"
},
"outputs": [],
"source": [
"for i in range(hps.n_samples):\n",
" librosa.output.write_wav(f'noisy_top_level_generation_{i}.wav', x[i], sr=44100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rQ4ersQ5OhZr"
},
"outputs": [],
"source": [
"Audio('noisy_top_level_generation_0.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-GdqzrGkOhZv"
},
"outputs": [],
"source": [
"Audio('noisy_top_level_generation_1.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gE5S8hyZOhZy"
},
"outputs": [],
"source": [
"Audio('noisy_top_level_generation_2.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t2-mEJaqZfuS"
},
"source": [
"If you don't like any of the options, return a few cells back to \"Sample a few options...\" and rerun from there."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o7CzSiv0MmFP"
},
"source": [
"## Choose your favorite sample and request longer generation\n",
"\n",
"---\n",
"\n",
"(Repeat from here)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j_XFtVi99CIY"
},
"outputs": [],
"source": [
"my_choice=0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pgk3sHHBLYoq"
},
"outputs": [],
"source": [
"zs[2]=zs[2][my_choice].repeat(hps.n_samples,1)\n",
"t.save(zs, 'zs-checkpoint2.t')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W8Rd9xxm565S"
},
"outputs": [],
"source": [
"# Set to True to load the previous checkpoint:\n",
"if False:\n",
" zs=t.load('zs-checkpoint2.t') "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k12xjMgHkRGP"
},
"source": [
"Choose the length of the continuation. The 1B model can generate up to 17 second samples and the 5B up to 23 seconds, but you'll want to pick a shorter continuation length so that it will be able to look back at what you've generated already. Here we've chosen 4 seconds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h3_-0a07kHHG"
},
"outputs": [],
"source": [
"continue_generation_in_seconds=4\n",
"tokens_to_sample = seconds_to_tokens(continue_generation_in_seconds, hps.sr, top_prior, chunk_size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GpPG3Ifqk8ue"
},
"source": [
"The next step asks the top prior to generate more of the sample. It'll take up to a few minutes, depending on the sample length you request."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YoHkeSTaEyLj"
},
"outputs": [],
"source": [
"zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
"x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymhUqEdhleEi"
},
"source": [
"Now listen to the longer versions of the sample you selected, and again choose a favorite sample. If you don't like any, return back to the cell where you can load the checkpoint, and continue again from there.\n",
"\n",
"When the samples start getting long, you might not always want to listen from the start, so change the playback start time later on if you like."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2H1LNLTa_R6a"
},
"outputs": [],
"source": [
"playback_start_time_in_seconds = 0 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r4SBGAmsnJtH"
},
"outputs": [],
"source": [
"for i in range(hps.n_samples):\n",
" librosa.output.write_wav(f'top_level_continuation_{i}.wav', x[i][playback_start_time_in_seconds*44100:], sr=44100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2WeyE5Qtnmeo"
},
"outputs": [],
"source": [
"Audio('top_level_continuation_0.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BKtfEtcaazXE"
},
"outputs": [],
"source": [
"Audio('top_level_continuation_1.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yrlS0XwK2S0"
},
"outputs": [],
"source": [
"Audio('top_level_continuation_2.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-OJT704dvnGv"
},
"source": [
"To make a longer song, return back to \"Choose your favorite sample\" and loop through that again"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzCrkCZJvUcQ"
},
"source": [
"# Upsample Co-Composition to Higher Audio Quality"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4MPgukwMmB0p"
},
"source": [
"Choose your favorite sample from your latest group of generations. (If you haven't already gone through the Co-Composition block, make sure to do that first so you have a generation to upsample)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yv-pNNPHBQYC"
},
"outputs": [],
"source": [
"choice = 0\n",
"select_best_sample = True # Set false if you want to upsample all your samples \n",
" # upsampling sometimes yields subtly different results on multiple runs,\n",
" # so this way you can choose your favorite upsampling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v17cEAqyCgfo"
},
"outputs": [],
"source": [
"if select_best_sample:\n",
" zs[2]=zs[2][choice].repeat(zs[2].shape[0],1)\n",
"\n",
"t.save(zs, 'zs-top-level-final.t')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qqlR9368s3jJ"
},
"outputs": [],
"source": [
"if False:\n",
" zs = t.load('zs-top-level-final.t')\n",
"\n",
"assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'\n",
"hps.sample_length = zs[2].shape[1]*top_prior.raw_to_tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jzHwF_iqgIWM"
},
"outputs": [],
"source": [
"# Set this False if your instance has 24GB or more of VRAM.\n",
"if True:\n",
" del top_prior\n",
" empty_cache()\n",
" top_prior=None\n",
"\n",
"upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q22Ier6YSkKS"
},
"outputs": [],
"source": [
"sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
" dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
" dict(fp16=True)]\n",
"\n",
"if type(labels)==dict:\n",
" labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [labels] "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T1MCa9_jnjpf"
},
"source": [
"This next step upsamples 2 levels. The level_1 samples will be available after around one hour (depending on the length of your sample) and are saved under {hps.name}/level_0/item_0.wav, while the fully upsampled level_0 will likely take 4-12 hours. You can access the wav files down below, or using the file browser."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NcNT5qIRMmHq"
},
"outputs": [],
"source": [
"zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W2jTYLPBc29M"
},
"outputs": [],
"source": [
"Audio(f'{hps.name}/level_0/item_0.wav')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "Interacting with Jukebox",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
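The length arithmetic in the notebook cells above is easy to check outside the notebook. The Python sketch below reproduces the `hps.sample_length` rounding and the co-composer's `seconds_to_tokens` helper (raising `ValueError` where the notebook uses `assert`). It assumes the 5b_lyrics and 1b_lyrics top priors use 128 raw audio samples per token, with context lengths of 8192 and 6144 tokens; those values are consistent with the notebook's hard-coded sample lengths (8192 * 128 = 1048576 and 6144 * 128 = 786432), but in a live session you should read `top_prior.raw_to_tokens` and `top_prior.n_ctx` instead.

```python
"""Reproduce the notebook's sample-length and token arithmetic (standalone sketch)."""

SR = 44100  # hps.sr in the notebook
RAW_TO_TOKENS = 128  # assumed top-level top_prior.raw_to_tokens
N_CTX = {"5b_lyrics": 8192, "1b_lyrics": 6144}  # assumed top_prior.n_ctx per model


def sample_length_for(seconds: float, sr: int = SR, raw_to_tokens: int = RAW_TO_TOKENS) -> int:
    """Round a duration down to a whole number of top-level tokens, in raw samples."""
    return (int(seconds * sr) // raw_to_tokens) * raw_to_tokens


def seconds_to_tokens(sec: float, sr: int, raw_to_tokens: int, n_ctx: int, chunk_size: int) -> int:
    """Seconds -> tokens, rounded to the next multiple of chunk_size strictly above."""
    tokens = int(sec * sr) // raw_to_tokens
    tokens = ((tokens // chunk_size) + 1) * chunk_size
    if tokens > n_ctx:
        raise ValueError("Choose a shorter generation length to stay within the top prior context")
    return tokens


if __name__ == "__main__":
    print(sample_length_for(50))  # 2204928 raw samples for the default 50-second request
    # Shortest sample_length the notebook's assert accepts for 5b_lyrics:
    print(N_CTX["5b_lyrics"] * RAW_TO_TOKENS)  # 1048576
    # Tokens for the 4-second initial co-composer generation (5b_lyrics, chunk_size 16):
    print(seconds_to_tokens(4, SR, RAW_TO_TOKENS, N_CTX["5b_lyrics"], 16))  # 1392
```

Under the same assumptions, this also accounts for the 23 and 17 second limits mentioned in the co-composing cells: 1048576 / 44100 is about 23.8 seconds and 786432 / 44100 is about 17.8 seconds.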