Last active: May 18, 2023 13:42
Save Mizuho32/fba4105ab95fad1e64b9cf1421c21597 to your computer and use it in GitHub Desktop.
test_esc.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import json\n", | |
"import torch\n", | |
"import librosa\n", | |
"from model.htsat import HTSAT_Swin_Transformer\n", | |
"import esc_config as config\n", | |
"import numpy as np\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Audio_Classification:\n", | |
" def __init__(self, model_path, config):\n", | |
" super().__init__()\n", | |
"\n", | |
" self.device = torch.device('cuda')\n", | |
" self.sed_model = HTSAT_Swin_Transformer(\n", | |
" spec_size=config.htsat_spec_size,\n", | |
" patch_size=config.htsat_patch_size,\n", | |
" in_chans=1,\n", | |
" num_classes=config.classes_num,\n", | |
" window_size=config.htsat_window_size,\n", | |
" config = config,\n", | |
" depths = config.htsat_depth,\n", | |
" embed_dim = config.htsat_dim,\n", | |
" patch_stride=config.htsat_stride,\n", | |
" num_heads=config.htsat_num_head\n", | |
" )\n", | |
" ckpt = torch.load(model_path, map_location=\"cuda\")\n", | |
" temp_ckpt = {}\n", | |
" for key in ckpt[\"state_dict\"]:\n", | |
" temp_ckpt[key[10:]] = ckpt['state_dict'][key]\n", | |
" self.sed_model.load_state_dict(temp_ckpt)\n", | |
" self.sed_model.to(self.device)\n", | |
" self.sed_model.eval()\n", | |
"\n", | |
"\n", | |
" def predict(self, audiofile):\n", | |
"\n", | |
" if audiofile:\n", | |
" waveform, sr = librosa.load(audiofile, sr=32000)\n", | |
"\n", | |
" with torch.no_grad():\n", | |
" x = torch.from_numpy(waveform).float().to(self.device)\n", | |
" output_dict = self.sed_model(x[None, :], None, True)\n", | |
" pred = output_dict['clipwise_output']\n", | |
" pred_post = pred[0].detach().cpu().numpy()\n", | |
" pred_label = np.argmax(pred_post)\n", | |
" pred_prob = np.max(pred_post)\n", | |
" return pred_label, pred_prob" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/mizuho/prjs/vsongrecog/hts-ast/data/venv/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n", | |
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" | |
] | |
}, | |
{ | |
"ename": "RuntimeError", | |
"evalue": "Input and output sizes should be greater than 0, but got input (H: 0, W: 64) output (H: 1024, W: 64)", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[3], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m Audiocls \u001b[38;5;241m=\u001b[39m Audio_Classification(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m./data/weights/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt\u001b[39m\u001b[38;5;124m'\u001b[39m, config)\n\u001b[0;32m----> 3\u001b[0m pred_label, pred_prob \u001b[38;5;241m=\u001b[39m \u001b[43mAudiocls\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m../media/zzNdwF40ID8_32000Hz.wav\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAudiocls predict output: \u001b[39m\u001b[38;5;124m'\u001b[39m, pred_label, pred_prob)\u001b[38;5;66;03m#, gd[\"1-7456-A-13.wav\"])\u001b[39;00m\n", | |
"Cell \u001b[0;32mIn[2], line 34\u001b[0m, in \u001b[0;36mAudio_Classification.predict\u001b[0;34m(self, audiofile)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 33\u001b[0m x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(waveform)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m---> 34\u001b[0m output_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msed_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m pred \u001b[38;5;241m=\u001b[39m output_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mclipwise_output\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 36\u001b[0m pred_post \u001b[38;5;241m=\u001b[39m pred[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n", | |
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/data/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", | |
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/model/htsat.py:781\u001b[0m, in \u001b[0;36mHTSAT_Swin_Transformer.forward\u001b[0;34m(self, x, mixup_lambda, infer_mode)\u001b[0m\n\u001b[1;32m 779\u001b[0m repeat_ratio \u001b[38;5;241m=\u001b[39m math\u001b[38;5;241m.\u001b[39mfloor(target_T \u001b[38;5;241m/\u001b[39m frame_num)\n\u001b[1;32m 780\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrepeat(repeats\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m1\u001b[39m,repeat_ratio,\u001b[38;5;241m1\u001b[39m))\n\u001b[0;32m--> 781\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape_wav2img\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m output_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward_features(x)\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39menable_repeat_mode:\n", | |
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/model/htsat.py:736\u001b[0m, in \u001b[0;36mHTSAT_Swin_Transformer.reshape_wav2img\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 734\u001b[0m \u001b[38;5;66;03m# to avoid bicubic zero error\u001b[39;00m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m T \u001b[38;5;241m<\u001b[39m target_T:\n\u001b[0;32m--> 736\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minterpolate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget_T\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbicubic\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m F \u001b[38;5;241m<\u001b[39m target_F:\n\u001b[1;32m 738\u001b[0m x \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39minterpolate(x, (x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m], target_F), mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbicubic\u001b[39m\u001b[38;5;124m\"\u001b[39m, align_corners\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", | |
"File \u001b[0;32m~/prjs/vsongrecog/hts-ast/data/venv/lib/python3.9/site-packages/torch/nn/functional.py:3967\u001b[0m, in \u001b[0;36minterpolate\u001b[0;34m(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)\u001b[0m\n\u001b[1;32m 3965\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m antialias:\n\u001b[1;32m 3966\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_nn\u001b[38;5;241m.\u001b[39m_upsample_bicubic2d_aa(\u001b[38;5;28minput\u001b[39m, output_size, align_corners, scale_factors)\n\u001b[0;32m-> 3967\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupsample_bicubic2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_factors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3969\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbilinear\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 3970\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 3D input, but bilinear mode needs 4D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: Input and output sizes should be greater than 0, but got input (H: 0, W: 64) output (H: 1024, W: 64)" | |
] | |
} | |
], | |
"source": [ | |
    "# Instantiate the classifier from the fine-tuned ESC-50 checkpoint and run\n",
    "# inference on a single 32 kHz wav file.\n",
    "Audiocls = Audio_Classification('./data/weights/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt', config)\n",
    "\n",
    "pred_label, pred_prob = Audiocls.predict(\"../media/zzNdwF40ID8_32000Hz.wav\")\n",
    "\n",
    "# NOTE(review): as recorded, this cell fails inside reshape_wav2img with\n",
    "# 'input (H: 0, W: 64)' -- presumably the loaded waveform is empty or too\n",
    "# short to yield any spectrogram frames; verify the wav path and its length.\n",
    "print('Audiocls predict output: ', pred_label, pred_prob)#, gd[\"1-7456-A-13.wav\"])"
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.8" | |
}, | |
"vscode": { | |
"interpreter": { | |
"hash": "cbf651ad57764a85bd0d6253f55158c06e8f62339527fb5be4597f7fba08e70b" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment