@davda54
Created March 11, 2020 16:06
Meta-Tasnet – stereo inference
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "youtube_infererence.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ji80NTs-o34T",
"colab_type": "text"
},
"source": [
"# **META-TASNET**\n",
"## Stereo separation of full-length songs – sample code\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1acunoJ0fdg0",
"colab_type": "text"
},
"source": [
"### 1. Initialize"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WPcaJd03acLe",
"colab_type": "code",
"colab": {}
},
"source": [
"!pip install youtube-dl\n",
"!pip install soundfile\n",
"!git clone https://github.com/pfnet-research/meta-tasnet\n",
"\n",
"!wget \"https://www.dropbox.com/s/zw6zgt3edd88v87/best_model.pt\"\n",
"\n",
"import youtube_dl, soundfile, librosa, os, sys, torch, IPython.display\n",
"import numpy as np\n",
"from IPython.display import HTML\n",
"from google.colab import output, files\n",
"\n",
"sys.path.append(\"/content/meta-tasnet\")\n",
"from model.tasnet import MultiTasNet\n",
"\n",
"output.clear()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zScNng1Rff3Z",
"colab_type": "text"
},
"source": [
"### 2. Load the saved model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "KvGoD3JndIMl",
"colab_type": "code",
"outputId": "db575725-32ad-4159-a84d-11eadcc71980",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")  # use the GPU if one is available\n",
"state = torch.load(\"best_model.pt\", map_location=device)  # load the checkpoint (map_location lets GPU-saved weights load on a CPU-only machine)\n",
"\n",
"network = MultiTasNet(state[\"args\"]).to(device)  # initialize the model\n",
"network.load_state_dict(state['state_dict'])  # load weights from the checkpoint"
],
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YgDm8ehsfkCK",
"colab_type": "text"
},
"source": [
"### 3. Define the separation procedure"
]
},
{
"cell_type": "code",
"metadata": {
"id": "49WcEMxOdeOI",
"colab_type": "code",
"colab": {}
},
"source": [
"# separate a stereo audio clip (shape: [2, T]) with sampling rate `rate`\n",
"def separate_sample(audio, rate: int):\n",
"\n",
"    # resample the mixture to the three sampling rates used by the model\n",
"    audio = audio.astype('float32')\n",
"    mix = [librosa.resample(audio, orig_sr=rate, target_sr=s, res_type='kaiser_best', fix=False) for s in [8000, 16000, 32000]]\n",
"    # pad/trim to (L + 1) * 2**i samples (L = length at 8 kHz), so each rate is exactly twice as long as the previous one\n",
"    mix = [librosa.util.fix_length(m, size=(mix[0].shape[-1] + 1) * (2**i)) for i, m in enumerate(mix)]\n",
"    mix = [torch.from_numpy(s).float().to(device).unsqueeze_(1) for s in mix]  # shape: (channel, 1, time)\n",
"    mix = [s / s.std(dim=-1, keepdim=True) for s in mix]  # standardize each channel\n",
"\n",
"    # the network is applied to each channel separately\n",
"    mix_left = [s[0:1, :, :] for s in mix]\n",
"    mix_right = [s[1:2, :, :] for s in mix]\n",
"    del mix\n",
"\n",
"    network.eval()\n",
"    with torch.no_grad():\n",
"        # [-1] keeps the output at the highest (32 kHz) rate\n",
"        separation_left = network.inference(mix_left, n_chunks=8)[-1].cpu().squeeze_(2)  # shape: (1, instrument, time)\n",
"        separation_right = network.inference(mix_right, n_chunks=8)[-1].cpu().squeeze_(2)  # shape: (1, instrument, time)\n",
"\n",
"    separation = torch.cat([separation_left, separation_right], 0).numpy()  # shape: (channel, instrument, time)\n",
"\n",
"    # resample every stem back to the original rate and trim it to the input length; shape: (time, channel)\n",
"    estimates = {\n",
"        'drums': librosa.resample(separation[:, 0, :], orig_sr=32000, target_sr=rate, res_type='kaiser_best', fix=True)[:, :audio.shape[1]].T,\n",
"        'bass': librosa.resample(separation[:, 1, :], orig_sr=32000, target_sr=rate, res_type='kaiser_best', fix=True)[:, :audio.shape[1]].T,\n",
"        'other': librosa.resample(separation[:, 2, :], orig_sr=32000, target_sr=rate, res_type='kaiser_best', fix=True)[:, :audio.shape[1]].T,\n",
"        'vocals': librosa.resample(separation[:, 3, :], orig_sr=32000, target_sr=rate, res_type='kaiser_best', fix=True)[:, :audio.shape[1]].T,\n",
"    }\n",
"\n",
"    # the network sees standardized input, so its output scale does not match the original audio:\n",
"    # for each channel, fit one gain per instrument by least squares so that the weighted stems sum to the mixture\n",
"    a_l = np.array([estimates['drums'][:, 0], estimates['bass'][:, 0], estimates['other'][:, 0], estimates['vocals'][:, 0]]).T  # shape: (time, instrument)\n",
"    a_r = np.array([estimates['drums'][:, 1], estimates['bass'][:, 1], estimates['other'][:, 1], estimates['vocals'][:, 1]]).T\n",
"\n",
"    b_l = audio[0, :]\n",
"    b_r = audio[1, :]\n",
"\n",
"    sol_l = np.linalg.lstsq(a_l, b_l, rcond=None)[0]  # one gain per instrument\n",
"    sol_r = np.linalg.lstsq(a_r, b_r, rcond=None)[0]\n",
"\n",
"    e_l = a_l * sol_l  # scale each column (instrument) by its gain\n",
"    e_r = a_r * sol_r\n",
"\n",
"    separation = np.array([e_l, e_r])  # shape: (channel, time, instrument)\n",
"\n",
"    estimates = {\n",
"        'drums': separation[:, :, 0],\n",
"        'bass': separation[:, :, 1],\n",
"        'other': separation[:, :, 2],\n",
"        'vocals': separation[:, :, 3],\n",
"    }\n",
"\n",
"    return estimates"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "MIfS3Qlmfo2B",
"colab_type": "text"
},
"source": [
"### 4. Load a song from YouTube\n",
"\n",
"Choose a song from the drop-down menu and press the play button to download its audio from YouTube.\n",
"\n",
"Note that you can override the variable `id` and separate any song you want: use the video ID from the YouTube URL (the value of the `v=` parameter, e.g. `5NV6Rdv1a3I` in `https://www.youtube.com/watch?v=5NV6Rdv1a3I`). Just keep in mind that the audio should be of high quality, which unfortunately isn't always the case on YouTube."
]
},
{
"cell_type": "code",
"metadata": {
"id": "vgL36Jjjdv_J",
"colab_type": "code",
"cellView": "both",
"outputId": "b820f5a2-88f2-42ce-d7de-f4aea35cde74",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"ids = {\n",
" \"Dire Straits - Sultans Of Swing (rock)\": \"0fAQhSRLQnM\", \n",
" \"Billie Eilish - Bad Guy (pop)\": \"DyDfgMOUjCI\",\n",
" \"Death - Spirit Crusher (death metal)\": \"4_rYk_aJbcQ\",\n",
" \"James Brown - Get Up (I Feel Like Being A) Sex Machine (funk)\": \"kwjHpi4rXb8\",\n",
" \"Věra Bílá & Kale - Pas o panori (world music)\": \"R-L477kx8LA\",\n",
" \"Eminem - Lose Yourself (hip-hop)\": \"nPA2czkOsFE\",\n",
" \"Sting - Englishman in New York (pop/rock)\": \"d27gTrPPAyk\",\n",
" \"R.E.M. - Losing my Religion (alternative rock)\": \"xwtdhWltSIg\",\n",
" \"AURORA – Animal (pop)\": \"3DIT8Y3LC6M\",\n",
" \"Red Hot Chili Peppers - Scar Tissue (alternative rock)\": \"mzJj5-lubeM\",\n",
" \"John Mayer - Gravity (blues)\": \"7VBex8zbDRs\",\n",
" \"Darude - Sandstorm (EDM)\": \"y6120QOlsfU\",\n",
" \"Pokemon (soundtrack)\": \"JuYeHPFR3f0\",\n",
" \"Daft Punk - Get Lucky (pop)\": \"5NV6Rdv1a3I\",\n",
" \"Maroon 5 feat. Christina Aguilera - Moves Like Jagger (pop)\": \"suRsxpoAc5w\"\n",
"}\n",
"\n",
"song = \"Daft Punk - Get Lucky (pop)\" #@param [\"Red Hot Chili Peppers - Scar Tissue (alternative rock)\", \"Daft Punk - Get Lucky (pop)\", 'Billie Eilish - Bad Guy (pop)','Death - Spirit Crusher (death metal)', \"James Brown - Get Up (I Feel Like Being A) Sex Machine (funk)\", 'Věra Bílá & Kale - Pas o panori (world music)','Eminem - Lose Yourself (hip-hop)','Sting - Englishman in New York (pop/rock)','R.E.M. - Losing my Religion (alternative rock)', \"AURORA – Animal (pop)\", 'Dire Straits - Sultans Of Swing (rock)', \"John Mayer - Gravity (blues)\", \"Darude - Sandstorm (EDM)\", \"Pokemon (soundtrack)\", \"Maroon 5 feat. Christina Aguilera - Moves Like Jagger (pop)\"]\n",
"\n",
"id = ids[song]  # change this to the ID of your own song\n",
"\n",
"ydl_opts = {\n",
"    'format': 'bestaudio/best',\n",
"    # convert to WAV; 'preferredquality' only affects lossy codecs, so the source sampling rate is kept\n",
"    'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav'}],\n",
"    'outtmpl': 'tmp.wav'\n",
"}\n",
"with youtube_dl.YoutubeDL(ydl_opts) as ydl:\n",
"    status = ydl.download([id])\n",
"\n",
"audio, rate = librosa.load('tmp.wav', mono=False, sr=None)\n",
"os.remove('tmp.wav')\n",
"\n",
"output.clear()\n",
"print(f\"shape: {audio.shape}, sampling rate: {rate} Hz, length: {audio.shape[1] / rate:.2f} s\")\n",
"print(f\"{song}\")"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"shape: (2, 11934998), sampling rate: 48000 Hz, length: 248.65 s\n",
"Daft Punk - Get Lucky (pop)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gVnKT9I8gJbB",
"colab_type": "text"
},
"source": [
"### 5. Separate!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kcROr5atXtNE",
"colab_type": "code",
"outputId": "21525c98-9703-4d06-de90-ad7a8088041a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"print(\"separating... \", end='')\n",
"estimates = separate_sample(audio, rate)\n",
"print(\"done\")\n",
"\n",
"print(f\"shape: {estimates['vocals'].shape}, sampling rate: {rate} Hz, length: {estimates['vocals'].shape[1] / rate:.2f} s\")"
],
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"text": [
"separating... done\n",
"shape: (2, 11934998), sampling rate: 48000 Hz, length: 248.65 s\n"
],
"name": "stdout"
}
]
}
]
}
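The `fix_length` call in `separate_sample` (step 3) forces the 8, 16 and 32 kHz versions of the mixture to (L + 1), 2(L + 1) and 4(L + 1) samples, where L is the length of the 8 kHz signal, so every resolution is exactly twice as long as the one below it. A minimal NumPy sketch of that pad-or-trim rule, assuming zero padding at the end of the time axis (the `librosa.util.fix_length` default); `align_lengths` is a hypothetical helper, not part of meta-tasnet:

```python
import numpy as np


def align_lengths(signals):
    """Pad or trim a list of signals (8, 16, 32 kHz) to (L + 1) * 2**i samples.

    Hypothetical helper mirroring the fix_length call in separate_sample;
    L is the length of the first (lowest-rate) signal along the last axis.
    """
    base = signals[0].shape[-1] + 1
    aligned = []
    for i, s in enumerate(signals):
        target = base * 2 ** i
        if s.shape[-1] >= target:
            # trim extra samples at the end
            aligned.append(s[..., :target])
        else:
            # zero-pad at the end of the last (time) axis only
            pad_width = [(0, 0)] * (s.ndim - 1) + [(0, target - s.shape[-1])]
            aligned.append(np.pad(s, pad_width))
    return aligned


# toy stereo signals whose lengths are off by a few samples after resampling
mix = [np.ones((2, 100)), np.ones((2, 199)), np.ones((2, 405))]
print([m.shape for m in align_lengths(mix)])  # [(2, 101), (2, 202), (2, 404)]
```

Padding the 8 kHz signal by one sample as well keeps the 2x relation exact even when resampling rounds lengths differently at each rate.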
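The end of `separate_sample` restores the scale of the stems: for each channel it stacks the four stems as the columns of a matrix A (time x instrument), solves min over g of ||A g - b||^2 for one gain per instrument (b is the original mixture of that channel), and multiplies each stem by its gain. The sketch below isolates that step on synthetic data; `rescale_stems` is a hypothetical name, and the stems are random noise chosen so that they mix exactly, in which case the fitted gains recover the true ones.

```python
import numpy as np


def rescale_stems(stems, mix):
    """Fit one gain per stem so that the scaled stems sum to the mixture.

    stems: array of shape (time, instrument), one column per separated source
    mix:   array of shape (time,), the original mixture of the same channel
    Returns (scaled stems of shape (time, instrument), gains of shape (instrument,)).
    """
    # least-squares solution of stems @ gains ~= mix
    gains = np.linalg.lstsq(stems, mix, rcond=None)[0]
    # broadcasting (time, instrument) * (instrument,) scales each column by its gain
    return stems * gains, gains


rng = np.random.default_rng(0)
stems = rng.standard_normal((1000, 4))       # stand-ins for drums, bass, other, vocals
true_gains = np.array([0.5, 2.0, 1.0, 3.0])
mix = stems @ true_gains                     # a mixture the stems explain exactly

scaled, gains = rescale_stems(stems, mix)
print(np.allclose(gains, true_gains))        # True
print(np.allclose(scaled.sum(axis=1), mix))  # True
```

On real separations the stems never explain the mixture exactly, so the fit only minimizes the residual; fitting the gains per channel, as the notebook does, lets left and right end up with different levels.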
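Step 4 asks for the video ID from the YouTube URL. Assuming a standard `watch?v=...` link or a `youtu.be/...` short link (other forms such as `/embed/` are not handled), the ID can be pulled out with the standard library; `video_id_from_url` is a hypothetical helper, not part of the notebook:

```python
from urllib.parse import parse_qs, urlparse


def video_id_from_url(url: str) -> str:
    """Return the YouTube video ID from a watch or youtu.be URL (hypothetical helper)."""
    parsed = urlparse(url)
    if parsed.netloc.endswith("youtu.be"):
        # short links carry the ID as the path: https://youtu.be/<id>
        return parsed.path.lstrip("/")
    # regular links carry it in the v= query parameter
    query = parse_qs(parsed.query)
    if "v" not in query:
        raise ValueError(f"no video ID found in {url!r}")
    return query["v"][0]


print(video_id_from_url("https://www.youtube.com/watch?v=5NV6Rdv1a3I"))  # 5NV6Rdv1a3I
print(video_id_from_url("https://youtu.be/5NV6Rdv1a3I"))                 # 5NV6Rdv1a3I
```

The returned string can be assigned to `id` in the step 4 cell in place of `ids[song]`.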