Created
June 23, 2021 17:44
-
-
Save ShigekiKarita/c3b513ce5e3bed9aeda726c4d2c2e200 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "0da2d661", | |
"metadata": {}, | |
"source": [ | |
"# CTC implementation with PyTorch JIT\n", | |
"\n", | |
"Requires PyTorch 1.9.0 and the `editdistance` pip package,\n", | |
"as well as the ESPnet AN4 data preparation outputs (`wav.scp` and `text`).\n", | |
"\n", | |
"author: Shigeki Karita (karita@ieee.org)\n", | |
"\n", | |
"https://pytorch.org/docs/stable/jit.html" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "171bf40b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:62: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n", | |
" warnings.warn(\"dropout option adds dropout after all but last \"\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"#params:\t 125\n", | |
"out shape:\t torch.Size([1])\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"graph(%self : __torch__.SpeechModel,\n", | |
" %x.1 : Tensor,\n", | |
" %xlen.1 : Tensor,\n", | |
" %y.1 : Tensor,\n", | |
" %ylen.1 : Tensor):\n", | |
" %5 : int[] = prim::Constant[value=[2]]()\n", | |
" %6 : int[] = prim::Constant[value=[1]]()\n", | |
" %7 : int[] = prim::Constant[value=[0]]()\n", | |
" %8 : int[] = prim::Constant[value=[5]]()\n", | |
" %9 : bool = prim::Constant[value=0]()\n", | |
" %10 : int = prim::Constant[value=0]() # <ipython-input-1-14e2ee0fb7fa>:85:24\n", | |
" %11 : int = prim::Constant[value=1]() # <ipython-input-1-14e2ee0fb7fa>:85:27\n", | |
" %12 : str = prim::Constant[value=\"input.size(-1) must be equal to input_size. Expected {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:206:16\n", | |
" %13 : str = prim::Constant[value=\"input must have {} dimensions, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:202:16\n", | |
" %14 : int = prim::Constant[value=3]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:199:63\n", | |
" %15 : str = prim::Constant[value=\"Expected hidden[0] size {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:622:31\n", | |
" %16 : str = prim::Constant[value=\"Expected hidden[1] size {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:624:31\n", | |
" %17 : bool = prim::Constant[value=1]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:66\n", | |
" %18 : str = prim::Constant[value=\"Expected more than 1 value per channel when training, got input size {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" %19 : float = prim::Constant[value=1.0000000000000001e-05]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/normalization.py:253:60\n", | |
" %20 : float = prim::Constant[value=0.5]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/dropout.py:58:32\n", | |
" %21 : int = prim::Constant[value=2]() # <ipython-input-1-14e2ee0fb7fa>:69:53\n", | |
" %22 : NoneType = prim::Constant()\n", | |
" %23 : int = prim::Constant[value=-1]() # <ipython-input-1-14e2ee0fb7fa>:72:29\n", | |
" %24 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name=\"encoder\"](%self)\n", | |
" %25 : Tensor = aten::slice(%x.1, %10, %22, %22, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n", | |
" %314 : Tensor = prim::profile[profiled_type=Float(2, 1000, strides=[1000, 1], requires_grad=0, device=cpu)](%25)\n", | |
" %26 : Tensor = aten::unsqueeze(%314, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n", | |
" %315 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%26)\n", | |
" %27 : Tensor = aten::slice(%315, %21, %22, %22, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n", | |
" %28 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"0\"](%24)\n", | |
" %29 : __torch__.torch.nn.modules.conv.Conv1d = prim::GetAttr[name=\"1\"](%24)\n", | |
" %30 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"2\"](%24)\n", | |
" %31 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"3\"](%24)\n", | |
" %32 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"5\"](%24)\n", | |
" %33 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"6\"](%24)\n", | |
" %34 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"7\"](%24)\n", | |
" %35 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"9\"](%24)\n", | |
" %36 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"10\"](%24)\n", | |
" %37 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"11\"](%24)\n", | |
" %38 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"13\"](%24)\n", | |
" %39 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"14\"](%24)\n", | |
" %40 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"15\"](%24)\n", | |
" %41 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"17\"](%24)\n", | |
" %42 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"18\"](%24)\n", | |
" %43 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"19\"](%24)\n", | |
" %44 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv1d = prim::GetAttr[name=\"21\"](%24)\n", | |
" %45 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"22\"](%24)\n", | |
" %46 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"23\"](%24)\n", | |
" %47 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv1d = prim::GetAttr[name=\"25\"](%24)\n", | |
" %48 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"26\"](%24)\n", | |
" %49 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"27\"](%24)\n", | |
" %50 : Tensor = prim::GetAttr[name=\"weight\"](%28)\n", | |
" %51 : Tensor = prim::GetAttr[name=\"bias\"](%28)\n", | |
" %316 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n", | |
" %52 : int = aten::size(%316, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %317 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n", | |
" %53 : int = aten::size(%317, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %54 : int = aten::mul(%52, %53) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %55 : int = aten::floordiv(%54, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %56 : int[] = prim::ListConstruct(%55, %11)\n", | |
" %318 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n", | |
" %57 : int[] = aten::size(%318) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %58 : int[] = aten::slice(%57, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %59 : int[] = aten::list(%58) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %60 : int[] = aten::add(%56, %59) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.2 : int = aten::__getitem__(%60, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %62 : int = aten::len(%60) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %63 : int = aten::sub(%62, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.4 : int = prim::Loop(%63, %17, %size_prods.2) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.2 : int, %size_prods.12 : int):\n", | |
" %67 : int = aten::add(%i.2, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %68 : int = aten::__getitem__(%60, %67) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.14 : int = aten::mul(%size_prods.12, %68) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.14)\n", | |
" %70 : bool = aten::eq(%size_prods.4, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%70) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %71 : str = aten::format(%18, %60) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%71) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %input.5 : Tensor = aten::group_norm(%27, %11, %50, %51, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %73 : Tensor = prim::GetAttr[name=\"weight\"](%29)\n", | |
" %74 : Tensor? = prim::GetAttr[name=\"bias\"](%29)\n", | |
" %input.9 : Tensor = aten::conv1d(%input.5, %73, %74, %8, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %76 : bool = prim::GetAttr[name=\"training\"](%30)\n", | |
" %319 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.9)\n", | |
" %input.13 : Tensor = aten::dropout(%319, %20, %76) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %78 : Tensor = prim::GetAttr[name=\"weight\"](%31)\n", | |
" %79 : Tensor = prim::GetAttr[name=\"bias\"](%31)\n", | |
" %320 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
" %80 : int = aten::size(%320, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %321 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
" %81 : int = aten::size(%321, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %82 : int = aten::mul(%80, %81) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %83 : int = aten::floordiv(%82, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %84 : int[] = prim::ListConstruct(%83, %11)\n", | |
" %322 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
" %85 : int[] = aten::size(%322) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %86 : int[] = aten::slice(%85, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %87 : int[] = aten::list(%86) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %88 : int[] = aten::add(%84, %87) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.16 : int = aten::__getitem__(%88, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %90 : int = aten::len(%88) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %91 : int = aten::sub(%90, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.18 : int = prim::Loop(%91, %17, %size_prods.16) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.4 : int, %size_prods.20 : int):\n", | |
" %95 : int = aten::add(%i.4, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %96 : int = aten::__getitem__(%88, %95) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.22 : int = aten::mul(%size_prods.20, %96) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.22)\n", | |
" %98 : bool = aten::eq(%size_prods.18, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%98) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %99 : str = aten::format(%18, %88) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%99) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %323 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
" %input.17 : Tensor = aten::group_norm(%323, %11, %78, %79, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %324 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.17)\n", | |
" %input.21 : Tensor = aten::gelu(%324) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %102 : Tensor = prim::GetAttr[name=\"weight\"](%32)\n", | |
" %103 : Tensor? = prim::GetAttr[name=\"bias\"](%32)\n", | |
" %325 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.21)\n", | |
" %input.25 : Tensor = aten::conv1d(%325, %102, %103, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %105 : bool = prim::GetAttr[name=\"training\"](%33)\n", | |
" %326 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.25)\n", | |
" %input.29 : Tensor = aten::dropout(%326, %20, %105) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %107 : Tensor = prim::GetAttr[name=\"weight\"](%34)\n", | |
" %108 : Tensor = prim::GetAttr[name=\"bias\"](%34)\n", | |
" %327 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
" %109 : int = aten::size(%327, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %328 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
" %110 : int = aten::size(%328, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %111 : int = aten::mul(%109, %110) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %112 : int = aten::floordiv(%111, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %113 : int[] = prim::ListConstruct(%112, %11)\n", | |
" %329 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
" %114 : int[] = aten::size(%329) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %115 : int[] = aten::slice(%114, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %116 : int[] = aten::list(%115) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %117 : int[] = aten::add(%113, %116) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.24 : int = aten::__getitem__(%117, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %119 : int = aten::len(%117) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %120 : int = aten::sub(%119, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.26 : int = prim::Loop(%120, %17, %size_prods.24) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.6 : int, %size_prods.28 : int):\n", | |
" %124 : int = aten::add(%i.6, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %125 : int = aten::__getitem__(%117, %124) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.30 : int = aten::mul(%size_prods.28, %125) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.30)\n", | |
" %127 : bool = aten::eq(%size_prods.26, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%127) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %128 : str = aten::format(%18, %117) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%128) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %330 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
" %input.33 : Tensor = aten::group_norm(%330, %11, %107, %108, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %331 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.33)\n", | |
" %input.37 : Tensor = aten::gelu(%331) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %131 : Tensor = prim::GetAttr[name=\"weight\"](%35)\n", | |
" %132 : Tensor? = prim::GetAttr[name=\"bias\"](%35)\n", | |
" %332 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.37)\n", | |
" %input.41 : Tensor = aten::conv1d(%332, %131, %132, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %134 : bool = prim::GetAttr[name=\"training\"](%36)\n", | |
" %333 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.41)\n", | |
" %input.45 : Tensor = aten::dropout(%333, %20, %134) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %136 : Tensor = prim::GetAttr[name=\"weight\"](%37)\n", | |
" %137 : Tensor = prim::GetAttr[name=\"bias\"](%37)\n", | |
" %334 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
" %138 : int = aten::size(%334, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %335 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
" %139 : int = aten::size(%335, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %140 : int = aten::mul(%138, %139) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %141 : int = aten::floordiv(%140, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %142 : int[] = prim::ListConstruct(%141, %11)\n", | |
" %336 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
" %143 : int[] = aten::size(%336) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %144 : int[] = aten::slice(%143, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %145 : int[] = aten::list(%144) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %146 : int[] = aten::add(%142, %145) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.32 : int = aten::__getitem__(%146, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %148 : int = aten::len(%146) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %149 : int = aten::sub(%148, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.34 : int = prim::Loop(%149, %17, %size_prods.32) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.8 : int, %size_prods.36 : int):\n", | |
" %153 : int = aten::add(%i.8, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %154 : int = aten::__getitem__(%146, %153) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.38 : int = aten::mul(%size_prods.36, %154) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.38)\n", | |
" %156 : bool = aten::eq(%size_prods.34, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%156) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %157 : str = aten::format(%18, %146) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%157) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %337 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
" %input.49 : Tensor = aten::group_norm(%337, %11, %136, %137, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %338 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.49)\n", | |
" %input.53 : Tensor = aten::gelu(%338) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %160 : Tensor = prim::GetAttr[name=\"weight\"](%38)\n", | |
" %161 : Tensor? = prim::GetAttr[name=\"bias\"](%38)\n", | |
" %339 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.53)\n", | |
" %input.57 : Tensor = aten::conv1d(%339, %160, %161, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %163 : bool = prim::GetAttr[name=\"training\"](%39)\n", | |
" %340 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.57)\n", | |
" %input.61 : Tensor = aten::dropout(%340, %20, %163) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %165 : Tensor = prim::GetAttr[name=\"weight\"](%40)\n", | |
" %166 : Tensor = prim::GetAttr[name=\"bias\"](%40)\n", | |
" %341 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
" %167 : int = aten::size(%341, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %342 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
" %168 : int = aten::size(%342, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %169 : int = aten::mul(%167, %168) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %170 : int = aten::floordiv(%169, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %171 : int[] = prim::ListConstruct(%170, %11)\n", | |
" %343 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
" %172 : int[] = aten::size(%343) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %173 : int[] = aten::slice(%172, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %174 : int[] = aten::list(%173) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %175 : int[] = aten::add(%171, %174) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.40 : int = aten::__getitem__(%175, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %177 : int = aten::len(%175) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %178 : int = aten::sub(%177, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.42 : int = prim::Loop(%178, %17, %size_prods.40) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.10 : int, %size_prods.44 : int):\n", | |
" %182 : int = aten::add(%i.10, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %183 : int = aten::__getitem__(%175, %182) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.46 : int = aten::mul(%size_prods.44, %183) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.46)\n", | |
" %185 : bool = aten::eq(%size_prods.42, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%185) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %186 : str = aten::format(%18, %175) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%186) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %344 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
" %input.65 : Tensor = aten::group_norm(%344, %11, %165, %166, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %345 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.65)\n", | |
" %input.69 : Tensor = aten::gelu(%345) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %189 : Tensor = prim::GetAttr[name=\"weight\"](%41)\n", | |
" %190 : Tensor? = prim::GetAttr[name=\"bias\"](%41)\n", | |
" %346 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.69)\n", | |
" %input.73 : Tensor = aten::conv1d(%346, %189, %190, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %192 : bool = prim::GetAttr[name=\"training\"](%42)\n", | |
" %347 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.73)\n", | |
" %input.77 : Tensor = aten::dropout(%347, %20, %192) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %194 : Tensor = prim::GetAttr[name=\"weight\"](%43)\n", | |
" %195 : Tensor = prim::GetAttr[name=\"bias\"](%43)\n", | |
" %348 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
" %196 : int = aten::size(%348, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %349 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
" %197 : int = aten::size(%349, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %198 : int = aten::mul(%196, %197) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %199 : int = aten::floordiv(%198, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %200 : int[] = prim::ListConstruct(%199, %11)\n", | |
" %350 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
" %201 : int[] = aten::size(%350) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %202 : int[] = aten::slice(%201, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %203 : int[] = aten::list(%202) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %204 : int[] = aten::add(%200, %203) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.48 : int = aten::__getitem__(%204, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %206 : int = aten::len(%204) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %207 : int = aten::sub(%206, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.50 : int = prim::Loop(%207, %17, %size_prods.48) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.12 : int, %size_prods.52 : int):\n", | |
" %211 : int = aten::add(%i.12, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %212 : int = aten::__getitem__(%204, %211) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.54 : int = aten::mul(%size_prods.52, %212) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.54)\n", | |
" %214 : bool = aten::eq(%size_prods.50, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%214) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %215 : str = aten::format(%18, %204) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%215) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %351 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
" %input.81 : Tensor = aten::group_norm(%351, %11, %194, %195, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %352 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.81)\n", | |
" %input.85 : Tensor = aten::gelu(%352) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %218 : Tensor = prim::GetAttr[name=\"weight\"](%44)\n", | |
" %219 : Tensor? = prim::GetAttr[name=\"bias\"](%44)\n", | |
" %353 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.85)\n", | |
" %input.89 : Tensor = aten::conv1d(%353, %218, %219, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %221 : bool = prim::GetAttr[name=\"training\"](%45)\n", | |
" %354 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.89)\n", | |
" %input.93 : Tensor = aten::dropout(%354, %20, %221) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %223 : Tensor = prim::GetAttr[name=\"weight\"](%46)\n", | |
" %224 : Tensor = prim::GetAttr[name=\"bias\"](%46)\n", | |
" %355 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
" %225 : int = aten::size(%355, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %356 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
" %226 : int = aten::size(%356, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %227 : int = aten::mul(%225, %226) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %228 : int = aten::floordiv(%227, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %229 : int[] = prim::ListConstruct(%228, %11)\n", | |
" %357 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
" %230 : int[] = aten::size(%357) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %231 : int[] = aten::slice(%230, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %232 : int[] = aten::list(%231) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %233 : int[] = aten::add(%229, %232) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.56 : int = aten::__getitem__(%233, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %235 : int = aten::len(%233) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %236 : int = aten::sub(%235, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods.58 : int = prim::Loop(%236, %17, %size_prods.56) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.14 : int, %size_prods.60 : int):\n", | |
" %240 : int = aten::add(%i.14, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %241 : int = aten::__getitem__(%233, %240) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.62 : int = aten::mul(%size_prods.60, %241) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.62)\n", | |
" %243 : bool = aten::eq(%size_prods.58, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%243) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %244 : str = aten::format(%18, %233) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%244) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %358 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
" %input.97 : Tensor = aten::group_norm(%358, %11, %223, %224, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %359 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.97)\n", | |
" %input.101 : Tensor = aten::gelu(%359) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %247 : Tensor = prim::GetAttr[name=\"weight\"](%47)\n", | |
" %248 : Tensor? = prim::GetAttr[name=\"bias\"](%47)\n", | |
" %360 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.101)\n", | |
" %input.105 : Tensor = aten::conv1d(%360, %247, %248, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
" %250 : bool = prim::GetAttr[name=\"training\"](%48)\n", | |
" %361 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.105)\n", | |
" %input.109 : Tensor = aten::dropout(%361, %20, %250) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
" %252 : Tensor = prim::GetAttr[name=\"weight\"](%49)\n", | |
" %253 : Tensor = prim::GetAttr[name=\"bias\"](%49)\n", | |
" %362 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
" %254 : int = aten::size(%362, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %363 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
" %255 : int = aten::size(%363, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
" %256 : int = aten::mul(%254, %255) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %257 : int = aten::floordiv(%256, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
" %258 : int[] = prim::ListConstruct(%257, %11)\n", | |
" %364 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
" %259 : int[] = aten::size(%364) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %260 : int[] = aten::slice(%259, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
" %261 : int[] = aten::list(%260) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
" %262 : int[] = aten::add(%258, %261) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
" %size_prods.1 : int = aten::__getitem__(%262, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
" %264 : int = aten::len(%262) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %265 : int = aten::sub(%264, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
" %size_prods : int = prim::Loop(%265, %17, %size_prods.1) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
" block0(%i.1 : int, %size_prods.11 : int):\n", | |
" %269 : int = aten::add(%i.1, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
" %270 : int = aten::__getitem__(%262, %269) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
" %size_prods.5 : int = aten::mul(%size_prods.11, %270) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
" -> (%17, %size_prods.5)\n", | |
" %272 : bool = aten::eq(%size_prods, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
" = prim::If(%272) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
" block0():\n", | |
" %273 : str = aten::format(%18, %262) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
" = prim::RaiseException(%273) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %365 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
" %input.113 : Tensor = aten::group_norm(%365, %11, %252, %253, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
" %366 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.113)\n", | |
" %input.117 : Tensor = aten::gelu(%366) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
" %367 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.117)\n", | |
" %h.1 : Tensor = aten::transpose(%367, %11, %21) # <ipython-input-1-14e2ee0fb7fa>:69:12\n", | |
" %277 : __torch__.torch.nn.modules.rnn.LSTM = prim::GetAttr[name=\"lstm\"](%self)\n", | |
" %368 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
" %max_batch_size.1 : int = aten::size(%368, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:658:29\n", | |
" %369 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
" %279 : int = prim::dtype(%369)\n", | |
" %370 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
" %280 : Device = prim::device(%370)\n", | |
" %281 : int[] = prim::ListConstruct(%11, %max_batch_size.1, %11)\n", | |
" %h_zeros.1 : Tensor = aten::zeros(%281, %279, %22, %280, %22) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:665:22\n", | |
" %c_zeros.1 : Tensor = aten::zeros(%281, %279, %22, %280, %22) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:668:22\n", | |
" %371 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
" %284 : int = aten::dim(%371) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:11\n", | |
" %285 : bool = aten::ne(%284, %14) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:11\n", | |
" = prim::If(%285) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:8\n", | |
" block0():\n", | |
" %286 : str = aten::format(%13, %14, %284) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:202:16\n", | |
" = prim::RaiseException(%286) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:201:12\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %372 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
" %287 : int = aten::size(%372, %23) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:30\n", | |
" %288 : bool = aten::ne(%11, %287) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:11\n", | |
" = prim::If(%288) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:8\n", | |
" block0():\n", | |
" %289 : str = aten::format(%12, %11, %287) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:206:16\n", | |
" = prim::RaiseException(%289) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:205:12\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %expected_hidden_size.3 : (int, int, int) = prim::TupleConstruct(%11, %max_batch_size.1, %11)\n", | |
" %373 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%h_zeros.1)\n", | |
" %291 : int[] = aten::size(%373) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
" %292 : int[] = prim::ListConstruct(%11, %max_batch_size.1, %11)\n", | |
" %293 : bool = aten::ne(%291, %292) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
" = prim::If(%293) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:8\n", | |
" block0():\n", | |
" %294 : int[] = aten::list(%291) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:64\n", | |
" %295 : str = aten::format(%15, %expected_hidden_size.3, %294) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:31\n", | |
" = prim::RaiseException(%295) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:12\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %374 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%c_zeros.1)\n", | |
" %296 : int[] = aten::size(%374) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
" %297 : bool = aten::ne(%296, %292) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
" = prim::If(%297) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:8\n", | |
" block0():\n", | |
" %298 : int[] = aten::list(%296) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:64\n", | |
" %299 : str = aten::format(%16, %expected_hidden_size.3, %298) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:31\n", | |
" = prim::RaiseException(%299) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:12\n", | |
" -> ()\n", | |
" block1():\n", | |
" -> ()\n", | |
" %300 : Tensor[] = prim::GetAttr[name=\"_flat_weights\"](%277)\n", | |
" %301 : bool = prim::GetAttr[name=\"training\"](%277)\n", | |
" %375 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%h_zeros.1)\n", | |
" %376 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%c_zeros.1)\n", | |
" %302 : Tensor[] = prim::ListConstruct(%375, %376)\n", | |
" %377 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
" %303 : Tensor, %304 : Tensor, %305 : Tensor = aten::lstm(%377, %302, %300, %17, %11, %20, %301, %9, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:679:21\n", | |
" %306 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name=\"fc\"](%self)\n", | |
" %307 : Tensor = prim::GetAttr[name=\"weight\"](%306)\n", | |
" %308 : Tensor = prim::GetAttr[name=\"bias\"](%306)\n", | |
" %378 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[1, 2, 1], requires_grad=1, device=cpu)](%303)\n", | |
" %379 : Tensor = prim::profile[profiled_type=Float(30, 1, strides=[1, 1], requires_grad=1, device=cpu)](%307)\n", | |
" %380 : Tensor = prim::profile[profiled_type=Float(30, strides=[1], requires_grad=1, device=cpu)](%308)\n", | |
" %h.9 : Tensor = aten::linear(%378, %379, %380) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1847:11\n", | |
" %381 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%h.9)\n", | |
" %z.1 : Tensor = aten::log_softmax(%381, %23, %22) # <ipython-input-1-14e2ee0fb7fa>:72:15\n", | |
" %382 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%z.1)\n", | |
" %h.2 : Tensor = aten::transpose(%382, %10, %11) # <ipython-input-1-14e2ee0fb7fa>:85:12\n", | |
" %383 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[30, 60, 1], requires_grad=1, device=cpu)](%h.2)\n", | |
" %312 : Tensor = aten::ctc_loss(%383, %y.1, %xlen.1, %ylen.1, %10, %11, %9) # <ipython-input-1-14e2ee0fb7fa>:86:15\n", | |
" %384 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%z.1)\n", | |
" %313 : (Tensor, Tensor) = prim::TupleConstruct(%312, %384)\n", | |
" = prim::profile()\n", | |
" return (%313)" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Jittable CTC model def\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"\n", | |
"\n", | |
"class SpeechModel(nn.Module):\n", | |
" def __init__(self, n_vocab, n_hid, n_feat=1, dropout=0.5, act=nn.GELU):\n", | |
" super().__init__()\n", | |
" self.encoder = nn.Sequential(\n", | |
" # Wav2vec 2.0 style encoder https://arxiv.org/abs/2006.11477\n", | |
" nn.GroupNorm(1, n_feat),\n", | |
" \n", | |
" nn.Conv1d(n_feat, n_hid, 10, 5),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
"\n", | |
" nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
" \n", | |
" nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
" \n", | |
" nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
" \n", | |
" nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
" \n", | |
" nn.Conv1d(n_hid, n_hid, 2, 2),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
"\n", | |
" nn.Conv1d(n_hid, n_hid, 2, 2),\n", | |
" nn.Dropout(dropout),\n", | |
" nn.GroupNorm(1, n_hid),\n", | |
" act(),\n", | |
" )\n", | |
" self.lstm = nn.LSTM(n_hid, n_hid, 1, dropout=dropout, batch_first=True)\n", | |
" self.fc = nn.Linear(n_hid, n_vocab)\n", | |
"\n", | |
" @torch.jit.ignore # Cannot jit?\n", | |
" def convlen(self, xlen):\n", | |
" xlen = xlen.float()\n", | |
" for m in self.encoder:\n", | |
" if isinstance(m, nn.Conv1d):\n", | |
" xlen = torch.floor((xlen - m.kernel_size[0]) / m.stride[0] + 1)\n", | |
" return xlen.long()\n", | |
" \n", | |
" @torch.jit.export\n", | |
" def inference(self, x):\n", | |
" \"\"\"Transcribe one wav sequence to ids.\"\"\"\n", | |
" ids = self.logprob(x[None])[0].argmax(-1)\n", | |
" ids = torch.unique_consecutive(ids)\n", | |
" return ids[ids != 0]\n", | |
"\n", | |
" @torch.jit.export\n", | |
" def logprob(self, x):\n", | |
" \"\"\"Predicts log softmax distribution (batch, time, class).\"\"\"\n", | |
" h = self.encoder(x[:, None, :]).transpose(1, 2)\n", | |
" h, _ = self.lstm(h)\n", | |
" h = self.fc(h)\n", | |
" return h.log_softmax(-1)\n", | |
" \n", | |
" @torch.jit.export\n", | |
" def forward(self, x, xlen, y, ylen):\n", | |
" \"\"\"Computes CTC loss.\n", | |
" \n", | |
" Args:\n", | |
" x: input feature, float (batch, time1)\n", | |
" xlen: encoded lengths, long (batch)\n", | |
" y: target ids, long (batch, time2)\n", | |
" ylen: target lengths, long (batch)\n", | |
" \"\"\"\n", | |
" z = self.logprob(x) # (batch, time1, n_vocab)\n", | |
" h = z.transpose(0, 1) # (time1, batch, n_vocab)\n", | |
" return torch.ctc_loss(h, y, xlen, ylen), z\n", | |
"\n", | |
"\n", | |
"model = torch.jit.script(SpeechModel(30, 1))\n", | |
"x = torch.randn(2, 1000)\n", | |
"xlen = model.convlen(torch.tensor([1000, 900]).long())\n", | |
"y = torch.ones(2, 10).long()\n", | |
"ylen = torch.tensor([10, 9]).long()\n", | |
"print(\"#params:\\t\", sum(v.numel() for v in model.state_dict().values()))\n", | |
"print(\"out shape:\\t\", model.inference(x[0]).shape)\n", | |
"model.graph_for(x, xlen, y, ylen)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "38f14b44", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"848it [00:00, 528494.77it/s]\n", | |
"848it [00:00, 4886.44it/s]\n", | |
"100it [00:00, 412014.15it/s]\n", | |
"100it [00:00, 4852.50it/s]\n", | |
"130it [00:00, 388638.29it/s]\n", | |
"130it [00:00, 4795.81it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Dataset\n", | |
"import os\n", | |
"import librosa\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"\n", | |
"class ESPnetDataset(torch.utils.data.Dataset):\n", | |
" def __init__(self, root, sr=16000):\n", | |
" self.root = root\n", | |
" self.sr = sr\n", | |
" self.texts = {}\n", | |
" self.wavs = {}\n", | |
" self.ids = []\n", | |
" self.tokenizer = None\n", | |
" self.vocab = set()\n", | |
" with open(os.path.join(root, \"text\"), \"r\") as f:\n", | |
" for line in tqdm(f):\n", | |
" i, t = line.split(maxsplit=1)\n", | |
" self.ids.append(i)\n", | |
" t = t.strip()\n", | |
" self.texts[i] = t\n", | |
" for c in t:\n", | |
" self.vocab.add(c)\n", | |
" with open(os.path.join(root, \"wav.scp\"), \"r\") as f:\n", | |
" for line in tqdm(f):\n", | |
" s = line.split()\n", | |
" path = os.path.join(root, \"../..\", s[-2])\n", | |
" self.wavs[s[0]] = torch.as_tensor(librosa.load(path, sr=sr)[0])\n", | |
"\n", | |
" def __len__(self):\n", | |
" return len(self.ids)\n", | |
" \n", | |
" def __getitem__(self, idx):\n", | |
" k = self.ids[idx]\n", | |
" wav = self.wavs[k]\n", | |
" return wav, self.tokenizer.encode(self.texts[k])\n", | |
" \n", | |
"\n", | |
"class Tokenizer:\n", | |
" def __init__(self, *datasets):\n", | |
" vocabs = []\n", | |
" for d in datasets:\n", | |
" vocabs.append(d.vocab)\n", | |
" d.tokenizer = self\n", | |
" tokens = sorted(list(set.union(*vocabs)))\n", | |
" self.s2i = {'[blank]': 0, '[unk]': 1}\n", | |
" self.i2s = {0: '[blank]', 1: '[unk]'}\n", | |
" for i, token in enumerate(tokens, len(self.s2i)):\n", | |
" self.i2s[i] = token\n", | |
" self.s2i[token] = i\n", | |
" \n", | |
" def __len__(self):\n", | |
" return len(self.s2i)\n", | |
"\n", | |
" def encode(self, s):\n", | |
" return torch.tensor([self.s2i[c] for c in s])\n", | |
"\n", | |
" def decode(self, e):\n", | |
" return \"\".join(self.i2s[i.item()] for i in e)\n", | |
"\n", | |
" \n", | |
"trainset = ESPnetDataset('../egs/an4/asr1/data/train_nodev/')\n", | |
"devset = ESPnetDataset('../egs/an4/asr1/data/train_dev/')\n", | |
"testset = ESPnetDataset('../egs/an4/asr1/data/test')\n", | |
"tokenizer = Tokenizer(trainset, devset, testset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "eb1be0e8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/cuda/__init__.py:83: UserWarning: \n", | |
" Found GPU%d %s which is of cuda capability %d.%d.\n", | |
" PyTorch no longer supports this GPU because it is too old.\n", | |
" The minimum cuda capability supported by this library is %d.%d.\n", | |
" \n", | |
" warnings.warn(old_gpu_warn.format(d, name, major, minor, min_arch // 10, min_arch % 10))\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"max xlen: 102400\n", | |
"max ylen: 57\n", | |
"#vocab: 29\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at /opt/conda/conda-bld/pytorch_1623448278899/work/aten/src/ATen/native/cudnn/RNN.cpp:924.)\n", | |
" return forward_call(*input, **kwargs)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch: 0\n", | |
"sec/epoch: 3.954852342605591\n", | |
"train loss: tensor(5.2251)\n", | |
"train cer: tensor(0.9882)\n", | |
"dev loss: tensor(3.0169)\n", | |
"dev cer: tensor(1.)\n", | |
"epoch: 5\n", | |
"sec/epoch: 3.640742063522339\n", | |
"train loss: tensor(2.9098)\n", | |
"train cer: tensor(1.)\n", | |
"dev loss: tensor(2.9442)\n", | |
"dev cer: tensor(1.)\n", | |
"epoch: 10\n", | |
"sec/epoch: 3.664271354675293\n", | |
"train loss: tensor(2.7860)\n", | |
"train cer: tensor(0.9860)\n", | |
"dev loss: tensor(2.7893)\n", | |
"dev cer: tensor(0.9835)\n", | |
"epoch: 15\n", | |
"sec/epoch: 3.6453144550323486\n", | |
"train loss: tensor(2.6531)\n", | |
"train cer: tensor(0.9509)\n", | |
"dev loss: tensor(2.6637)\n", | |
"dev cer: tensor(0.9455)\n", | |
"epoch: 20\n", | |
"sec/epoch: 3.6640310287475586\n", | |
"train loss: tensor(2.4639)\n", | |
"train cer: tensor(0.9403)\n", | |
"dev loss: tensor(2.5285)\n", | |
"dev cer: tensor(0.9422)\n", | |
"epoch: 25\n", | |
"sec/epoch: 3.6508846282958984\n", | |
"train loss: tensor(2.2603)\n", | |
"train cer: tensor(0.7138)\n", | |
"dev loss: tensor(2.3292)\n", | |
"dev cer: tensor(0.7413)\n", | |
"epoch: 30\n", | |
"sec/epoch: 3.669067859649658\n", | |
"train loss: tensor(2.0850)\n", | |
"train cer: tensor(0.6624)\n", | |
"dev loss: tensor(2.1549)\n", | |
"dev cer: tensor(0.6562)\n", | |
"epoch: 35\n", | |
"sec/epoch: 3.6770923137664795\n", | |
"train loss: tensor(1.8892)\n", | |
"train cer: tensor(0.5947)\n", | |
"dev loss: tensor(2.0249)\n", | |
"dev cer: tensor(0.6347)\n", | |
"epoch: 40\n", | |
"sec/epoch: 3.6583163738250732\n", | |
"train loss: tensor(1.6512)\n", | |
"train cer: tensor(0.4884)\n", | |
"dev loss: tensor(1.8241)\n", | |
"dev cer: tensor(0.5076)\n", | |
"epoch: 45\n", | |
"sec/epoch: 3.678729295730591\n", | |
"train loss: tensor(1.4269)\n", | |
"train cer: tensor(0.3909)\n", | |
"dev loss: tensor(1.7563)\n", | |
"dev cer: tensor(0.4751)\n", | |
"epoch: 50\n", | |
"sec/epoch: 3.658202886581421\n", | |
"train loss: tensor(1.2555)\n", | |
"train cer: tensor(0.3106)\n", | |
"dev loss: tensor(1.5976)\n", | |
"dev cer: tensor(0.3817)\n", | |
"epoch: 55\n", | |
"sec/epoch: 3.6857497692108154\n", | |
"train loss: tensor(1.0824)\n", | |
"train cer: tensor(0.2570)\n", | |
"dev loss: tensor(1.4412)\n", | |
"dev cer: tensor(0.3340)\n", | |
"epoch: 60\n", | |
"sec/epoch: 3.6613011360168457\n", | |
"train loss: tensor(0.9420)\n", | |
"train cer: tensor(0.2111)\n", | |
"dev loss: tensor(1.4324)\n", | |
"dev cer: tensor(0.3054)\n", | |
"epoch: 65\n", | |
"sec/epoch: 3.6761670112609863\n", | |
"train loss: tensor(0.8255)\n", | |
"train cer: tensor(0.1789)\n", | |
"dev loss: tensor(1.3962)\n", | |
"dev cer: tensor(0.2882)\n", | |
"epoch: 70\n", | |
"sec/epoch: 3.682213306427002\n", | |
"train loss: tensor(0.7396)\n", | |
"train cer: tensor(0.1536)\n", | |
"dev loss: tensor(1.3824)\n", | |
"dev cer: tensor(0.2519)\n", | |
"epoch: 75\n", | |
"sec/epoch: 3.653740882873535\n", | |
"train loss: tensor(0.6194)\n", | |
"train cer: tensor(0.1323)\n", | |
"dev loss: tensor(1.5055)\n", | |
"dev cer: tensor(0.2020)\n", | |
"epoch: 80\n", | |
"sec/epoch: 3.6865620613098145\n", | |
"train loss: tensor(0.5580)\n", | |
"train cer: tensor(0.1158)\n", | |
"dev loss: tensor(1.4025)\n", | |
"dev cer: tensor(0.2222)\n", | |
"epoch: 85\n", | |
"sec/epoch: 3.6628165245056152\n", | |
"train loss: tensor(0.4663)\n", | |
"train cer: tensor(0.0924)\n", | |
"dev loss: tensor(1.4166)\n", | |
"dev cer: tensor(0.1741)\n", | |
"epoch: 90\n", | |
"sec/epoch: 3.6864466667175293\n", | |
"train loss: tensor(0.4085)\n", | |
"train cer: tensor(0.0800)\n", | |
"dev loss: tensor(1.4796)\n", | |
"dev cer: tensor(0.1821)\n", | |
"epoch: 95\n", | |
"sec/epoch: 3.664776563644409\n", | |
"train loss: tensor(0.3649)\n", | |
"train cer: tensor(0.0709)\n", | |
"dev loss: tensor(1.4671)\n", | |
"dev cer: tensor(0.2182)\n", | |
"epoch: 100\n", | |
"sec/epoch: 3.677140235900879\n", | |
"train loss: tensor(0.3022)\n", | |
"train cer: tensor(0.0581)\n", | |
"dev loss: tensor(1.4241)\n", | |
"dev cer: tensor(0.1634)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Training\n", | |
"import time\n", | |
"import itertools\n", | |
"import editdistance\n", | |
"\n", | |
"# hyperparams\n", | |
"n_batch = 16\n", | |
"n_epoch = 100\n", | |
"n_hid = 128\n", | |
"lr = 0.001\n", | |
"dropout = 0.0\n", | |
"clip = 1.0\n", | |
"device = torch.device('cuda')\n", | |
"max_xlen = max(map(len, itertools.chain(trainset.wavs.values(), devset.wavs.values(), testset.wavs.values())))\n", | |
"max_ylen = max(map(len, itertools.chain(trainset.texts.values(), devset.texts.values(), testset.texts.values())))\n", | |
"print(\"max xlen:\", max_xlen)\n", | |
"print(\"max ylen:\", max_ylen)\n", | |
"print(\"#vocab:\", len(tokenizer))\n", | |
"\n", | |
"def collate(batch):\n", | |
" # print(batch)\n", | |
" xlen, ylen = [], []\n", | |
" xpad = torch.zeros(n_batch, max_xlen)\n", | |
" ypad = torch.zeros(n_batch, max_ylen)\n", | |
" for i, (x, y) in enumerate(batch):\n", | |
" xpad[i, :len(x)] = x\n", | |
" ypad[i, :len(y)] = y\n", | |
" xlen.append(len(x))\n", | |
" ylen.append(len(y))\n", | |
" return xpad, torch.tensor(xlen), ypad, torch.tensor(ylen)\n", | |
"\n", | |
"\n", | |
"def cer(pred, plen, ypad, ylen):\n", | |
" ids = pred.argmax(-1)\n", | |
" err = 0\n", | |
" n = 0\n", | |
" for p, pl, y, yl in zip(ids, plen, ypad, ylen):\n", | |
" p = torch.unique_consecutive(p[:pl])\n", | |
" p = p[p != 0] # filter blank\n", | |
" err += editdistance.eval(p, y[:yl])\n", | |
" n += yl\n", | |
" return err / n\n", | |
"\n", | |
"\n", | |
"train_loader = torch.utils.data.DataLoader(trainset, n_batch, collate_fn=collate, shuffle=True, drop_last=True)\n", | |
"dev_loader = torch.utils.data.DataLoader(devset, n_batch, collate_fn=collate, shuffle=False, drop_last=True)\n", | |
"test_loader = torch.utils.data.DataLoader(testset, n_batch, collate_fn=collate, shuffle=False, drop_last=True)\n", | |
"\n", | |
"model = torch.jit.script(SpeechModel(len(tokenizer), n_hid=n_hid, dropout=dropout))\n", | |
"model.to(device)\n", | |
"optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", | |
"for epoch in range(n_epoch + 1):\n", | |
" model.train()\n", | |
" loss_hist = []\n", | |
" cer_hist = []\n", | |
" start = time.time()\n", | |
" for xpad, xlen, ypad, ylen in train_loader:\n", | |
" optimizer.zero_grad()\n", | |
" hlen = model.convlen(xlen)\n", | |
" loss, pred = model(xpad.to(device), hlen.to(device),\n", | |
" ypad.to(device), ylen.to(device))\n", | |
" loss.backward()\n", | |
" nn.utils.clip_grad_norm_(model.parameters(), clip)\n", | |
" optimizer.step()\n", | |
" cer_hist.append(cer(pred, hlen, ypad, ylen))\n", | |
" loss_hist.append(loss.float())\n", | |
" if epoch % 5 == 0:\n", | |
" print(\"epoch:\", epoch)\n", | |
" print(\"sec/epoch:\", time.time() - start)\n", | |
" print(\"train loss:\", torch.mean(torch.tensor(loss_hist)).float())\n", | |
" print(\"train cer:\", torch.mean(torch.tensor(cer_hist)).float())\n", | |
"\n", | |
" loss_hist = []\n", | |
" cer_hist = []\n", | |
" model.eval()\n", | |
" with torch.no_grad():\n", | |
" for xpad, xlen, ypad, ylen in dev_loader:\n", | |
" hlen = model.convlen(xlen)\n", | |
" loss, pred = model(xpad.to(device), hlen.to(device),\n", | |
" ypad.to(device), ylen.to(device))\n", | |
" cer_hist.append(cer(pred, hlen, ypad, ylen))\n", | |
" loss_hist.append(loss.float())\n", | |
" if epoch % 5 == 0:\n", | |
" print(\"dev loss:\", torch.mean(torch.tensor(loss_hist)).float())\n", | |
" print(\"dev cer:\", torch.mean(torch.tensor(cer_hist)).float())\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "0525926a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"pred: \n", | |
"truth: YES\n", | |
"pred: FIFB SFTY SIX\n", | |
"truth: FIFTY ONE FIFTY SIX\n", | |
"pred: U FIVE TWO ONE SEVEN\n", | |
"truth: ONE FIVE TWO ONE SEVEN\n", | |
"pred: ETER FIVE\n", | |
"truth: ENTER FIVE\n", | |
"pred: UBOU J IU TWE\n", | |
"truth: RUBOUT J U I P THREE TWO EIGHT\n", | |
"pred: J LR E\n", | |
"truth: J E N N I F E R\n" | |
] | |
} | |
], | |
"source": [ | |
"model.cpu()\n", | |
"model.eval()\n", | |
"for i, (xs, xlen, ys, ylen) in enumerate(dev_loader):\n", | |
" with torch.no_grad():\n", | |
" x = xs[0, :xlen[0]]\n", | |
" pred = model.inference(x)\n", | |
"\n", | |
" print(\"pred: \", tokenizer.decode(pred))\n", | |
" print(\"truth:\", tokenizer.decode(ys[0, :ylen[0]]))\n", | |
" if i == 5:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "742e96c7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.jit.save(model, 'ctc.pt')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "e2f818e1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loaded = torch.jit.load('ctc.pt')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "360beaf7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"pred: \n", | |
"truth: YES\n", | |
"pred: FIFB SFTY SIX\n", | |
"truth: FIFTY ONE FIFTY SIX\n", | |
"pred: U FIVE TWO ONE SEVEN\n", | |
"truth: ONE FIVE TWO ONE SEVEN\n", | |
"pred: ETER FIVE\n", | |
"truth: ENTER FIVE\n", | |
"pred: UBOU J IU TWE\n", | |
"truth: RUBOUT J U I P THREE TWO EIGHT\n", | |
"pred: J LR E\n", | |
"truth: J E N N I F E R\n" | |
] | |
} | |
], | |
"source": [ | |
"for i, (xs, xlen, ys, ylen) in enumerate(dev_loader):\n", | |
" with torch.no_grad():\n", | |
" x = xs[0, :xlen[0]]\n", | |
" pred = loaded.inference(x)\n", | |
"\n", | |
" print(\"pred: \", tokenizer.decode(pred))\n", | |
" print(\"truth:\", tokenizer.decode(ys[0, :ylen[0]]))\n", | |
" if i == 5:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "f2e5dcbd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Thu Jun 24 02:31:24 2021 \r\n", | |
"+-----------------------------------------------------------------------------+\r\n", | |
"| NVIDIA-SMI 460.80 Driver Version: 460.80 CUDA Version: 11.2 |\r\n", | |
"|-------------------------------+----------------------+----------------------+\r\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n", | |
"| | | MIG M. |\r\n", | |
"|===============================+======================+======================|\r\n", | |
"| 0 GeForce GTX 760 Off | 00000000:01:00.0 N/A | N/A |\r\n", | |
"| 36% 48C P8 N/A / N/A | 121MiB / 1996MiB | N/A Default |\r\n", | |
"| | | N/A |\r\n", | |
"+-------------------------------+----------------------+----------------------+\r\n", | |
"| 1 GeForce GTX 108... Off | 00000000:02:00.0 Off | N/A |\r\n", | |
"| 0% 51C P8 10W / 250W | 2099MiB / 11178MiB | 0% Default |\r\n", | |
"| | | N/A |\r\n", | |
"+-------------------------------+----------------------+----------------------+\r\n", | |
" \r\n", | |
"+-----------------------------------------------------------------------------+\r\n", | |
"| Processes: |\r\n", | |
"| GPU GI CI PID Type Process name GPU Memory |\r\n", | |
"| ID ID Usage |\r\n", | |
"|=============================================================================|\r\n", | |
"| 1 N/A N/A 1456 G /usr/lib/xorg/Xorg 4MiB |\r\n", | |
"| 1 N/A N/A 25056 C ...net/tools/venv/bin/python 2091MiB |\r\n", | |
"+-----------------------------------------------------------------------------+\r\n" | |
] | |
} | |
], | |
"source": [ | |
"# show the GPU / driver state of this machine (context for the sec/epoch timings above)\n", |
"!nvidia-smi" |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1c686df1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment