@aidiary
Created February 19, 2024 12:03
audiogen_demo.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyNWlmoOQ5qz7A+eZM7LgufJ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/aidiary/40cb12a1345159b363aee3ddc21e31ef/audiogen_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# AudioGen\n",
"\n",
"https://github.com/facebookresearch/audiocraft/blob/main/demos/audiogen_demo.ipynb"
],
"metadata": {
"id": "5WFJ-8U7KSIn"
}
},
{
"cell_type": "code",
"source": [
"%pip install audiocraft"
],
"metadata": {
"id": "1S471mOQKRJZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Load the pretrained model"
],
"metadata": {
"id": "4hW_grOKLQHV"
}
},
{
"cell_type": "code",
"source": [
"from audiocraft.models import AudioGen\n",
"\n",
"# https://huggingface.co/facebook/audiogen-medium\n",
"model = AudioGen.get_pretrained(\"facebook/audiogen-medium\")"
],
"metadata": {
"id": "-91JOQxtKVgX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Generation parameters\n",
"# use_sampling (bool, optional): if True, sample from the distribution; otherwise use argmax decoding. Defaults to True.\n",
"# top_k (int, optional): top_k used for sampling. Defaults to 250.\n",
"# top_p (float, optional): top_p used for sampling; when set to 0, top_k is used instead. Defaults to 0.0.\n",
"# temperature (float, optional): softmax temperature. Defaults to 1.0.\n",
"# duration (float, optional): duration of the generated waveform, in seconds. Defaults to 10.0.\n",
"# cfg_coef (float, optional): coefficient used for classifier-free guidance. Defaults to 3.0.\n",
"model.set_generation_params(\n",
" use_sampling=True,\n",
" top_k=250,\n",
" duration=5\n",
")"
],
"metadata": {
"id": "1cXM-qRdK34b"
},
"execution_count": null,
"outputs": []
},
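{
"cell_type": "markdown",
"source": [
"The parameters above control how each audio token is drawn during generation. As a rough illustration, the next cell sketches temperature scaling followed by top-k sampling on a single logits vector. `sample_top_k` is a hypothetical helper written for this note, and the vocabulary size of 2048 is an arbitrary choice for the example; audiocraft's own sampling code may differ in detail."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"\n",
"def sample_top_k(logits, k=250, temperature=1.0):\n",
"    # Hypothetical helper for illustration; not part of audiocraft.\n",
"    k = min(k, logits.shape[-1])\n",
"    # Temperature rescales the logits before the softmax.\n",
"    probs = torch.softmax(logits / temperature, dim=-1)\n",
"    # Keep only the k most likely tokens and renormalize their probabilities.\n",
"    top_probs, top_idx = torch.topk(probs, k, dim=-1)\n",
"    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)\n",
"    # Draw one token among the k candidates and map it back to a vocabulary index.\n",
"    choice = torch.multinomial(top_probs, num_samples=1)\n",
"    return top_idx.gather(-1, choice)\n",
"\n",
"\n",
"# Random logits stand in for one decoding step (vocabulary size is an assumption).\n",
"logits = torch.randn(1, 2048)\n",
"sample_top_k(logits, k=250)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},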
{
"cell_type": "markdown",
"source": [
"## Generate a continuation of the given audio"
],
"metadata": {
"id": "p_R2y_TdLo7Z"
}
},
{
"cell_type": "code",
"source": [
"import math\n",
"import torchaudio\n",
"import torch\n",
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"\n",
"def get_bip_bip(bip_duration=0.125, frequency=440, duration=0.5, sample_rate=16000, device=\"cuda\"):\n",
" \"\"\"Generate a series of beeps at the given frequency.\"\"\"\n",
" # Build the time axis on the requested device rather than a hardcoded one\n",
" t = torch.arange(int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate\n",
" wav = torch.cos(2 * math.pi * frequency * t)[None]\n",
" tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
" envelope = (tp >= 0.5).float()\n",
" return wav * envelope"
],
"metadata": {
"id": "wMZsYGVRLNEA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"audio_samples = get_bip_bip(0.125)\n",
"audio_samples.shape"
],
"metadata": {
"id": "xi2npWVjMmJw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(audio_samples.squeeze().cpu())\n",
"display_audio(audio_samples, 16000)"
],
"metadata": {
"id": "mD0Y9SizMsyQ"
},
"execution_count": null,
"outputs": []
},
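{
"cell_type": "code",
"source": [
"# expand() repeats the clip along a new batch dimension (as a view, without copying data), so the batch holds two identical samples\n",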
"audio_samples.expand(2, -1, -1).shape"
],
"metadata": {
"id": "z43SB4obNLzy"
},
"execution_count": null,
"outputs": []
},
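{
"cell_type": "code",
"source": [
"# Generate two continuations of the beep-beep sound\n",
"# The audio prompt is passed as a batch of two\n",
"# Text descriptions are optional; without them the continuation is not text-conditioned\n",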
"res = model.generate_continuation(\n",
" audio_samples.expand(2, -1, -1),\n",
" 16000,\n",
" # [\"Whistling with wind blowing\", \"Typing on a typewriter\"],\n",
" progress=True)"
],
"metadata": {
"id": "z8XTQ6FtMx2M"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"res.shape"
],
"metadata": {
"id": "7u36fehRNyTZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"display_audio(res, 16000)"
],
"metadata": {
"id": "gvuOgKnJNGhP"
},
"execution_count": null,
"outputs": []
},
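{
"cell_type": "markdown",
"source": [
"To keep the generated continuations after the session ends, they can be written to disk. The next cell is a minimal sketch using audiocraft's `audio_write` helper with loudness normalization; the `continuation_{idx}` file stem is a name chosen for this example, and `res` is assumed to be the tensor returned by `generate_continuation` above."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from audiocraft.data.audio import audio_write\n",
"\n",
"for idx, one_wav in enumerate(res):\n",
"    # Writes continuation_{idx}.wav (file stem chosen for illustration), normalized to a target loudness.\n",
"    audio_write(f'continuation_{idx}', one_wav.cpu(), model.sample_rate, strategy='loudness', loudness_compressor=True)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},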
{
"cell_type": "markdown",
"source": [
"## Text-conditioned generation"
],
"metadata": {
"id": "aWc_4AvVPDmm"
}
},
{
"cell_type": "code",
"source": [
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"output = model.generate(\n",
" descriptions=[\n",
" \"Subway train blowing its horn\",\n",
" \"A cat meowing\",\n",
" ],\n",
" progress=True\n",
")\n",
"\n",
"display_audio(output, sample_rate=16000)"
],
"metadata": {
"id": "EeIemxfoNvev"
},
"execution_count": null,
"outputs": []
}
]
}