Skip to content

Instantly share code, notes, and snippets.

@Mizuho32
Last active May 18, 2023 13:42
Show Gist options
  • Save Mizuho32/fba4105ab95fad1e64b9cf1421c21597 to your computer and use it in GitHub Desktop.
test_esc.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import librosa\n",
"from model.htsat import HTSAT_Swin_Transformer\n",
"import esc_config as config\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Audio_Classification:\n",
"    \"\"\"Inference wrapper around a pretrained HTSAT_Swin_Transformer checkpoint.\"\"\"\n",
"\n",
"    def __init__(self, model_path, config):\n",
"        \"\"\"Build the HTSAT model from `config` and load weights from `model_path`.\n",
"\n",
"        model_path: path to a checkpoint whose 'state_dict' keys are\n",
"                    prefixed with 'sed_model.' by the training wrapper.\n",
"        config:     esc_config module holding the HTSAT hyperparameters.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"\n",
"        # Fall back to CPU so the notebook also runs on machines without a GPU.\n",
"        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"        self.sed_model = HTSAT_Swin_Transformer(\n",
"            spec_size=config.htsat_spec_size,\n",
"            patch_size=config.htsat_patch_size,\n",
"            in_chans=1,\n",
"            num_classes=config.classes_num,\n",
"            window_size=config.htsat_window_size,\n",
"            config=config,\n",
"            depths=config.htsat_depth,\n",
"            embed_dim=config.htsat_dim,\n",
"            patch_stride=config.htsat_stride,\n",
"            num_heads=config.htsat_num_head\n",
"        )\n",
"        # map_location must follow self.device, otherwise loading a CUDA\n",
"        # checkpoint on a CPU-only machine raises.\n",
"        ckpt = torch.load(model_path, map_location=self.device)\n",
"        # Strip the training wrapper's 'sed_model.' prefix (10 chars, was the\n",
"        # magic key[10:]) so the keys match this bare model.\n",
"        prefix = 'sed_model.'\n",
"        temp_ckpt = {}\n",
"        for key in ckpt[\"state_dict\"]:\n",
"            temp_ckpt[key[len(prefix):]] = ckpt[\"state_dict\"][key]\n",
"        self.sed_model.load_state_dict(temp_ckpt)\n",
"        self.sed_model.to(self.device)\n",
"        self.sed_model.eval()\n",
"\n",
"    def predict(self, audiofile):\n",
"        \"\"\"Classify one audio file; returns (label_index, probability).\n",
"\n",
"        Returns None when `audiofile` is falsy (original behavior kept).\n",
"        \"\"\"\n",
"        if audiofile:\n",
"            waveform, sr = librosa.load(audiofile, sr=32000)\n",
"\n",
"            with torch.no_grad():\n",
"                x = torch.from_numpy(waveform).float().to(self.device)\n",
"                output_dict = self.sed_model(x[None, :], None, True)\n",
"                pred = output_dict['clipwise_output']\n",
"                pred_post = pred[0].detach().cpu().numpy()\n",
"                pred_label = np.argmax(pred_post)\n",
"                pred_prob = np.max(pred_post)\n",
"            return pred_label, pred_prob"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mizuho/prjs/vsongrecog/hts-ast/data/venv/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"
]
},
{
"ename": "RuntimeError",
"evalue": "Input and output sizes should be greater than 0, but got input (H: 0, W: 64) output (H: 1024, W: 64)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m Audiocls \u001b[38;5;241m=\u001b[39m Audio_Classification(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m./data/weights/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt\u001b[39m\u001b[38;5;124m'\u001b[39m, config)\n\u001b[0;32m----> 3\u001b[0m pred_label, pred_prob \u001b[38;5;241m=\u001b[39m \u001b[43mAudiocls\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m../media/zzNdwF40ID8_32000Hz.wav\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAudiocls predict output: \u001b[39m\u001b[38;5;124m'\u001b[39m, pred_label, pred_prob)\u001b[38;5;66;03m#, gd[\"1-7456-A-13.wav\"])\u001b[39;00m\n",
"Cell \u001b[0;32mIn[2], line 34\u001b[0m, in \u001b[0;36mAudio_Classification.predict\u001b[0;34m(self, audiofile)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 33\u001b[0m x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(waveform)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m---> 34\u001b[0m output_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msed_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m pred \u001b[38;5;241m=\u001b[39m output_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mclipwise_output\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 36\u001b[0m pred_post \u001b[38;5;241m=\u001b[39m pred[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n",
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/data/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/model/htsat.py:781\u001b[0m, in \u001b[0;36mHTSAT_Swin_Transformer.forward\u001b[0;34m(self, x, mixup_lambda, infer_mode)\u001b[0m\n\u001b[1;32m 779\u001b[0m repeat_ratio \u001b[38;5;241m=\u001b[39m math\u001b[38;5;241m.\u001b[39mfloor(target_T \u001b[38;5;241m/\u001b[39m frame_num)\n\u001b[1;32m 780\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrepeat(repeats\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m1\u001b[39m,repeat_ratio,\u001b[38;5;241m1\u001b[39m))\n\u001b[0;32m--> 781\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape_wav2img\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m output_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward_features(x)\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39menable_repeat_mode:\n",
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/model/htsat.py:736\u001b[0m, in \u001b[0;36mHTSAT_Swin_Transformer.reshape_wav2img\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 734\u001b[0m \u001b[38;5;66;03m# to avoid bicubic zero error\u001b[39;00m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m T \u001b[38;5;241m<\u001b[39m target_T:\n\u001b[0;32m--> 736\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minterpolate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget_T\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbicubic\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m F \u001b[38;5;241m<\u001b[39m target_F:\n\u001b[1;32m 738\u001b[0m x \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39minterpolate(x, (x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m], target_F), mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbicubic\u001b[39m\u001b[38;5;124m\"\u001b[39m, align_corners\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/data/venv/lib/python3.9/site-packages/torch/nn/functional.py:3967\u001b[0m, in \u001b[0;36minterpolate\u001b[0;34m(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)\u001b[0m\n\u001b[1;32m 3965\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m antialias:\n\u001b[1;32m 3966\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_nn\u001b[38;5;241m.\u001b[39m_upsample_bicubic2d_aa(\u001b[38;5;28minput\u001b[39m, output_size, align_corners, scale_factors)\n\u001b[0;32m-> 3967\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupsample_bicubic2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_factors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3969\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbilinear\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 3970\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 3D input, but bilinear mode needs 4D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mRuntimeError\u001b[0m: Input and output sizes should be greater than 0, but got input (H: 0, W: 64) output (H: 1024, W: 64)"
]
}
],
"source": [
"# NOTE(review): this run currently fails inside reshape_wav2img (bicubic\n",
"# interpolate gets H=0, i.e. the clip yields zero spectrogram frames); the\n",
"# recorded traceback below documents the bug this gist reports.\n",
"Audiocls = Audio_Classification('./data/weights/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt', config)\n",
"\n",
"pred_label, pred_prob = Audiocls.predict(\"../media/zzNdwF40ID8_32000Hz.wav\")\n",
"\n",
"print('Audiocls predict output: ', pred_label, pred_prob)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.8"
},
"vscode": {
"interpreter": {
"hash": "cbf651ad57764a85bd0d6253f55158c06e8f62339527fb5be4597f7fba08e70b"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment