encodec_demo.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"authorship_tag": "ABX9TyOIAdGfbowpmJTQ0JHIT4sL",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/aidiary/3921ee9fbaba8297c0a55b7db701424b/encodec.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Encodec\n",
"\n",
"https://github.com/facebookresearch/encodec"
],
"metadata": {
"id": "0ar0ZVOt98w4"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IpL-wtHj9wxN"
},
"outputs": [],
"source": [
"%pip install datasets"
]
},
{
"cell_type": "markdown",
"source": [
"## ダミーの音声データをロード"
],
"metadata": {
"id": "G3ryQMxU_sMy"
}
},
{
"cell_type": "code",
"source": [
"from datasets import load_dataset, Audio"
],
"metadata": {
"id": "6WavreZP-uWr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
"librispeech_dummy"
],
"metadata": {
"id": "xCC6Bbjf_H9y"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"librispeech_dummy[0]"
],
"metadata": {
"id": "h84IQoh7_u7u"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(librispeech_dummy[0][\"audio\"][\"array\"])"
],
"metadata": {
"id": "GkZ-3HHKACFD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# datasets.Audioと衝突するためリネーム\n",
"from IPython.display import display, Audio as IPAudio"
],
"metadata": {
"id": "FzpuRAYUAQmL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(librispeech_dummy[0][\"text\"])\n",
"IPAudio(filename=librispeech_dummy[0][\"audio\"][\"path\"])"
],
"metadata": {
"id": "Ua0FOgXyAqKg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"librispeech_dummy[0][\"audio\"][\"sampling_rate\"]"
],
"metadata": {
"id": "jULI1eyEA1KX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 学習済みモデルをロード"
],
"metadata": {
"id": "j32bM7PnBTHz"
}
},
{
"cell_type": "code",
"source": [
"# モデルも実装済み\n",
"# TODO: AutoModelでもロードできる?\n",
"from transformers import EncodecModel, AutoProcessor"
],
"metadata": {
"id": "g6r9kjzlBSeY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = EncodecModel.from_pretrained(\"facebook/encodec_24khz\")\n",
"processor = AutoProcessor.from_pretrained(\"facebook/encodec_24khz\")"
],
"metadata": {
"id": "gRR8WVPdBPFE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"processor"
],
"metadata": {
"id": "P0CffyjoBmNf"
},
"execution_count": null,
"outputs": []
},
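{
"cell_type": "markdown",
"source": [
"As a quick sanity check, the loaded configuration can be inspected. The attribute names `sampling_rate` and `target_bandwidths` below are assumed from the `transformers` `EncodecConfig` and may differ across library versions."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sanity check (attribute names assumed from EncodecConfig)\n",
"print(model.config.sampling_rate)      # expected: 24000 for facebook/encodec_24khz\n",
"print(model.config.target_bandwidths)  # expected: [1.5, 3.0, 6.0, 12.0, 24.0]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},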
{
"cell_type": "markdown",
"source": [
"## 入力音声のサンプリングレート変換"
],
"metadata": {
"id": "9tTPl2WUBx2I"
}
},
{
"cell_type": "code",
"source": [
"Audio(sampling_rate=processor.sampling_rate)"
],
"metadata": {
"id": "nXx53OBcBoA6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 16kHzを24kHzにキャスト\n",
"librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n",
"librispeech_dummy[0]"
],
"metadata": {
"id": "wJJKoqdUB5Uq"
},
"execution_count": null,
"outputs": []
},
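{
"cell_type": "markdown",
"source": [
"A small check that the cast actually changed the sampling rate while keeping the clip duration (`len(array) / sampling_rate`) roughly the same."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sanity check: the resampled clip should report 24 kHz and keep its duration\n",
"resampled = librispeech_dummy[0][\"audio\"]\n",
"print(resampled[\"sampling_rate\"])  # expected: 24000\n",
"print(len(resampled[\"array\"]) / resampled[\"sampling_rate\"], \"seconds\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},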
{
"cell_type": "markdown",
"source": [
"## データの前処理"
],
"metadata": {
"id": "b9cPsUuFCNMo"
}
},
{
"cell_type": "code",
"source": [
"audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\n",
"audio_sample.shape"
],
"metadata": {
"id": "SddcwAKHCC_y"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"inputs = processor(raw_audio=audio_sample,\n",
" sampling_rate=processor.sampling_rate,\n",
" return_tensors=\"pt\")\n",
"inputs.keys()"
],
"metadata": {
"id": "IZeIZu-sCT1r"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# モデルへの入力は3Dテンソル\n",
"inputs[\"input_values\"].shape"
],
"metadata": {
"id": "WMDgdC-LCewN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"inputs[\"padding_mask\"].shape"
],
"metadata": {
"id": "Ujjxbfm5CmD7"
},
"execution_count": null,
"outputs": []
},
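{
"cell_type": "markdown",
"source": [
"With a single, unpadded clip the `padding_mask` should be all ones (every sample valid); a quick check of that assumption."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# For one unpadded clip the mask is expected to be all ones\n",
"print(bool(inputs[\"padding_mask\"].all()))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},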
{
"cell_type": "markdown",
"source": [
"## 推論"
],
"metadata": {
"id": "R87qZMqoC8n-"
}
},
{
"cell_type": "code",
"source": [
"# 音声波形をNeural Codecに変換\n",
"encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n",
"encoder_outputs.keys()"
],
"metadata": {
"id": "HBCX3qv2CnrG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"encoder_outputs[\"audio_codes\"].shape"
],
"metadata": {
"id": "Ditwja_vDCm-"
},
"execution_count": null,
"outputs": []
},
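{
"cell_type": "markdown",
"source": [
"A rough interpretation of the `audio_codes` shape, assuming the layout `(num_chunks, batch, num_codebooks, num_frames)` for the 24 kHz model; dividing the frame count by the clip duration should give the codec frame rate, roughly 75 Hz."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Assumed layout: (num_chunks, batch, num_codebooks, num_frames)\n",
"num_frames = encoder_outputs[\"audio_codes\"].shape[-1]\n",
"duration_sec = audio_sample.shape[-1] / processor.sampling_rate\n",
"print(\"codebooks:\", encoder_outputs[\"audio_codes\"].shape[-2])\n",
"print(\"frames per second:\", num_frames / duration_sec)  # roughly 75 Hz for the 24 kHz model"
],
"metadata": {},
"execution_count": null,
"outputs": []
},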
{
"cell_type": "code",
"source": [
"encoder_outputs[\"audio_scales\"]"
],
"metadata": {
"id": "Uxy53F4gDNz_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Neural Codecから音声波形を復元\n",
"audio_values = model.decode(encoder_outputs.audio_codes,\n",
" encoder_outputs.audio_scales,\n",
" inputs[\"padding_mask\"])[0]\n",
"audio_values.shape"
],
"metadata": {
"id": "frLXNRbZDQO0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"IPAudio(data=audio_sample, rate=24000)"
],
"metadata": {
"id": "ebLEGHa_Dvka"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 復元音質\n",
"# この設定ではいまいちな音質(明らかに劣化がわかる)\n",
"IPAudio(data=audio_values.squeeze().detach().numpy(), rate=24000)"
],
"metadata": {
"id": "5HMZOWUrD6Ot"
},
"execution_count": null,
"outputs": []
},
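{
"cell_type": "markdown",
"source": [
"The encode/decode round trip can presumably also be done in a single call: the model's forward pass returns both `audio_codes` and the reconstructed `audio_values` (as described in the `transformers` Encodec documentation)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# One-shot forward pass: encode + decode in a single call\n",
"outputs = model(inputs[\"input_values\"], inputs[\"padding_mask\"])\n",
"print(outputs.audio_codes.shape, outputs.audio_values.shape)\n",
"IPAudio(data=outputs.audio_values.squeeze().detach().numpy(), rate=24000)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},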
{
"cell_type": "markdown",
"source": [
"## bandwidthを調整して品質を上げる"
],
"metadata": {
"id": "D9I862Y8EzbY"
}
},
{
"cell_type": "code",
"source": [
"# デフォルトはbandwidth=1.5\n",
"# [1.5, 3.0, 6.0, 12.0, 24.0]がサポートされている\n",
"def reconstruction(inputs, bandwidth=1.5):\n",
" encoder_outputs = model.encode(inputs[\"input_values\"],\n",
" inputs[\"padding_mask\"],\n",
" bandwidth=bandwidth)\n",
" print(encoder_outputs.audio_codes.shape)\n",
"\n",
" audio_values = model.decode(encoder_outputs.audio_codes,\n",
" encoder_outputs.audio_scales,\n",
" inputs[\"padding_mask\"])[0]\n",
"\n",
" display(IPAudio(data=audio_values.squeeze().detach().numpy(), rate=24000))"
],
"metadata": {
"id": "S5uu778qEyJd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# bandwidth=12.0くらいで元の品質に近いくらい\n",
"# それでも悪化がわかる\n",
"\n",
"print(\"*** original\")\n",
"display(IPAudio(data=audio_sample, rate=24000))\n",
"\n",
"for bandwidth in [1.5, 3.0, 6.0, 12.0, 24.0]:\n",
" print(\"***\", bandwidth)\n",
" reconstruction(inputs, bandwidth)"
],
"metadata": {
"id": "JWciEtiCFuK_"
},
"execution_count": null,
"outputs": []
}
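,
{
"cell_type": "markdown",
"source": [
"A back-of-the-envelope bitrate estimate, assuming the 24 kHz model uses 75 frames per second and 1024-entry (10-bit) codebooks as described in the Encodec paper; the number of codebooks per frame then maps directly onto the nominal bandwidth in kbps."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Rough bitrate estimate (assumptions: 75 frames/s, 10 bits per codebook entry)\n",
"frame_rate = 75\n",
"bits_per_code = 10\n",
"for num_codebooks in [2, 4, 8, 16, 32]:\n",
"    kbps = num_codebooks * bits_per_code * frame_rate / 1000\n",
"    print(f\"{num_codebooks} codebooks -> {kbps} kbps\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
}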
]
}