-
-
Save aidiary/3921ee9fbaba8297c0a55b7db701424b to your computer and use it in GitHub Desktop.
encodec_demo.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"private_outputs": true, | |
"provenance": [], | |
"authorship_tag": "ABX9TyOIAdGfbowpmJTQ0JHIT4sL", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/aidiary/3921ee9fbaba8297c0a55b7db701424b/encodec.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Encodec\n", | |
"\n", | |
"https://github.com/facebookresearch/encodec" | |
], | |
"metadata": { | |
"id": "0ar0ZVOt98w4" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "IpL-wtHj9wxN" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Install every third-party package this notebook imports: datasets (with its\n", | |
"# audio extra, needed to decode/resample the audio column) and transformers.\n", | |
"%pip install -q \"datasets[audio]\" transformers" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## ダミーの音声データをロード" | |
], | |
"metadata": { | |
"id": "G3ryQMxU_sMy" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from datasets import load_dataset, Audio" | |
], | |
"metadata": { | |
"id": "6WavreZP-uWr" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n", | |
"librispeech_dummy" | |
], | |
"metadata": { | |
"id": "xCC6Bbjf_H9y" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"librispeech_dummy[0]" | |
], | |
"metadata": { | |
"id": "h84IQoh7_u7u" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"# Plot the first utterance's waveform with a title and axis labels so the\n", | |
"# figure can be read on its own; the trailing ';' hides the repr output.\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(librispeech_dummy[0][\"audio\"][\"array\"])\n", | |
"ax.set(title=\"librispeech_dummy[0] waveform\", xlabel=\"sample index\", ylabel=\"amplitude\");" | |
], | |
"metadata": { | |
"id": "GkZ-3HHKACFD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# datasets.Audioと衝突するためリネーム\n", | |
"from IPython.display import display, Audio as IPAudio" | |
], | |
"metadata": { | |
"id": "FzpuRAYUAQmL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fetch the row once, then play the decoded samples at their own sampling rate.\n", | |
"# The stored \"path\" is not guaranteed to point at a file on local disk, so it\n", | |
"# is not used for playback.\n", | |
"first_sample = librispeech_dummy[0]\n", | |
"print(first_sample[\"text\"])\n", | |
"IPAudio(data=first_sample[\"audio\"][\"array\"], rate=first_sample[\"audio\"][\"sampling_rate\"])" | |
], | |
"metadata": { | |
"id": "Ua0FOgXyAqKg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"librispeech_dummy[0][\"audio\"][\"sampling_rate\"]" | |
], | |
"metadata": { | |
"id": "jULI1eyEA1KX" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 学習済みモデルをロード" | |
], | |
"metadata": { | |
"id": "j32bM7PnBTHz" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# モデルも実装済み\n", | |
"# TODO: AutoModelでもロードできる?\n", | |
"from transformers import EncodecModel, AutoProcessor" | |
], | |
"metadata": { | |
"id": "g6r9kjzlBSeY" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = EncodecModel.from_pretrained(\"facebook/encodec_24khz\")\n", | |
"processor = AutoProcessor.from_pretrained(\"facebook/encodec_24khz\")" | |
], | |
"metadata": { | |
"id": "gRR8WVPdBPFE" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"processor" | |
], | |
"metadata": { | |
"id": "P0CffyjoBmNf" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 入力音声のサンプリングレート変換" | |
], | |
"metadata": { | |
"id": "9tTPl2WUBx2I" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"Audio(sampling_rate=processor.sampling_rate)" | |
], | |
"metadata": { | |
"id": "nXx53OBcBoA6" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 16kHzを24kHzにキャスト\n", | |
"librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n", | |
"librispeech_dummy[0]" | |
], | |
"metadata": { | |
"id": "wJJKoqdUB5Uq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## データの前処理" | |
], | |
"metadata": { | |
"id": "b9cPsUuFCNMo" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\n", | |
"audio_sample.shape" | |
], | |
"metadata": { | |
"id": "SddcwAKHCC_y" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Pass the sampling rate the dataset's audio column actually has, so the\n", | |
"# processor can reject a mismatch (e.g. if the resampling cell was skipped).\n", | |
"# Echoing processor.sampling_rate back to the processor would always pass.\n", | |
"inputs = processor(raw_audio=audio_sample,\n", | |
"                   sampling_rate=librispeech_dummy.features[\"audio\"].sampling_rate,\n", | |
"                   return_tensors=\"pt\")\n", | |
"inputs.keys()" | |
], | |
"metadata": { | |
"id": "IZeIZu-sCT1r" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# モデルへの入力は3Dテンソル\n", | |
"inputs[\"input_values\"].shape" | |
], | |
"metadata": { | |
"id": "WMDgdC-LCewN" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"inputs[\"padding_mask\"].shape" | |
], | |
"metadata": { | |
"id": "Ujjxbfm5CmD7" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 推論" | |
], | |
"metadata": { | |
"id": "R87qZMqoC8n-" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 音声波形をNeural Codecに変換\n", | |
"encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n", | |
"encoder_outputs.keys()" | |
], | |
"metadata": { | |
"id": "HBCX3qv2CnrG" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoder_outputs[\"audio_codes\"].shape" | |
], | |
"metadata": { | |
"id": "Ditwja_vDCm-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoder_outputs[\"audio_scales\"]" | |
], | |
"metadata": { | |
"id": "Uxy53F4gDNz_" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Neural Codecから音声波形を復元\n", | |
"audio_values = model.decode(encoder_outputs.audio_codes,\n", | |
" encoder_outputs.audio_scales,\n", | |
" inputs[\"padding_mask\"])[0]\n", | |
"audio_values.shape" | |
], | |
"metadata": { | |
"id": "frLXNRbZDQO0" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Original audio, played at the processor's rate rather than a hard-coded 24000.\n", | |
"IPAudio(data=audio_sample, rate=processor.sampling_rate)" | |
], | |
"metadata": { | |
"id": "ebLEGHa_Dvka" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Reconstructed audio.\n", | |
"# With this setting the quality is mediocre (the degradation is clearly audible).\n", | |
"# Played at the processor's rate rather than a hard-coded 24000.\n", | |
"IPAudio(data=audio_values.squeeze().detach().numpy(), rate=processor.sampling_rate)" | |
], | |
"metadata": { | |
"id": "5HMZOWUrD6Ot" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## bandwidthを調整して品質を上げる" | |
], | |
"metadata": { | |
"id": "D9I862Y8EzbY" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The default is bandwidth=1.5.\n", | |
"# [1.5, 3.0, 6.0, 12.0, 24.0] are supported.\n", | |
"def reconstruction(inputs, bandwidth=1.5):\n", | |
"    \"\"\"Encode ``inputs`` at ``bandwidth``, decode them back, and display a player.\n", | |
"\n", | |
"    Args:\n", | |
"        inputs: Processor output providing \"input_values\" and \"padding_mask\".\n", | |
"        bandwidth: Target bandwidth forwarded to ``model.encode``.\n", | |
"\n", | |
"    Prints the shape of the audio codes and displays the reconstructed audio;\n", | |
"    returns nothing. Uses the notebook-level ``model`` and ``processor``.\n", | |
"    \"\"\"\n", | |
"    import torch\n", | |
"\n", | |
"    # Inference only: no need to record an autograd graph for the round trip.\n", | |
"    with torch.no_grad():\n", | |
"        encoder_outputs = model.encode(inputs[\"input_values\"],\n", | |
"                                       inputs[\"padding_mask\"],\n", | |
"                                       bandwidth=bandwidth)\n", | |
"        print(encoder_outputs.audio_codes.shape)\n", | |
"\n", | |
"        audio_values = model.decode(encoder_outputs.audio_codes,\n", | |
"                                    encoder_outputs.audio_scales,\n", | |
"                                    inputs[\"padding_mask\"])[0]\n", | |
"\n", | |
"    # Play back at the processor's sampling rate instead of a hard-coded 24000.\n", | |
"    waveform = audio_values.squeeze().detach().numpy()\n", | |
"    display(IPAudio(data=waveform, rate=processor.sampling_rate))" | |
], | |
"metadata": { | |
"id": "S5uu778qEyJd" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Around bandwidth=12.0 the result gets close to the original quality,\n", | |
"# but some degradation is still noticeable.\n", | |
"\n", | |
"print(\"*** original\")\n", | |
"# Use the processor's sampling rate rather than a hard-coded 24000.\n", | |
"display(IPAudio(data=audio_sample, rate=processor.sampling_rate))\n", | |
"\n", | |
"for bandwidth in [1.5, 3.0, 6.0, 12.0, 24.0]:\n", | |
"    print(\"***\", bandwidth)\n", | |
"    reconstruction(inputs, bandwidth)" | |
], | |
"metadata": { | |
"id": "JWciEtiCFuK_" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment