@ShigekiKarita
Created June 23, 2021 17:44
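The notebook below scripts a CTC speech model with `torch.jit.script` and dumps its TorchScript graph. As a minimal, hedged sketch of that workflow (a toy model, not the notebook's `SpeechModel`; names like `TinyCTC` are made up for illustration, assuming torch >= 1.9):

```python
import torch
import torch.nn as nn


class TinyCTC(nn.Module):
    """Toy CTC model: linear projection + CTC loss (illustrative only)."""

    def __init__(self, idim: int = 8, odim: int = 5):
        super().__init__()
        self.proj = nn.Linear(idim, odim)
        # blank=0 matches the common ESPnet convention
        self.ctc = nn.CTCLoss(blank=0, zero_infinity=True)

    def forward(self, x, xlen, y, ylen):
        # x: (batch, time, idim); CTCLoss wants log-probs as (time, batch, odim)
        logp = self.proj(x).log_softmax(-1).transpose(0, 1)
        return self.ctc(logp, y, xlen, ylen)


# Script the module and inspect its graph, as the notebook does
m = torch.jit.script(TinyCTC())
x = torch.randn(2, 10, 8)
xlen = torch.full((2,), 10, dtype=torch.long)
y = torch.randint(1, 5, (2, 4))
ylen = torch.full((2,), 4, dtype=torch.long)
loss = m(x, xlen, y, ylen)
print(loss.shape)         # scalar loss (reduction='mean')
print(m.graph)            # TorchScript IR, like the dump below
```

Calling `m.graph` is what produces the long `graph(%self : ...)` IR listing seen in the notebook's output.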
{
"cells": [
{
"cell_type": "markdown",
"id": "0da2d661",
"metadata": {},
"source": [
"# CTC implementation with PyTorch JIT\n",
"\n",
"Requires pytorch 1.9.0 and editdistance pip packages\n",
"And espnet AN4 data prep (wav.scp and text)\n",
"\n",
"author: Shigeki Karita (karita@ieee.org)\n",
"\n",
"https://pytorch.org/docs/stable/jit.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "171bf40b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:62: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
" warnings.warn(\"dropout option adds dropout after all but last \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"#params:\t 125\n",
"out shape:\t torch.Size([1])\n"
]
},
{
"data": {
"text/plain": [
"graph(%self : __torch__.SpeechModel,\n",
" %x.1 : Tensor,\n",
" %xlen.1 : Tensor,\n",
" %y.1 : Tensor,\n",
" %ylen.1 : Tensor):\n",
" %5 : int[] = prim::Constant[value=[2]]()\n",
" %6 : int[] = prim::Constant[value=[1]]()\n",
" %7 : int[] = prim::Constant[value=[0]]()\n",
" %8 : int[] = prim::Constant[value=[5]]()\n",
" %9 : bool = prim::Constant[value=0]()\n",
" %10 : int = prim::Constant[value=0]() # <ipython-input-1-14e2ee0fb7fa>:85:24\n",
" %11 : int = prim::Constant[value=1]() # <ipython-input-1-14e2ee0fb7fa>:85:27\n",
" %12 : str = prim::Constant[value=\"input.size(-1) must be equal to input_size. Expected {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:206:16\n",
" %13 : str = prim::Constant[value=\"input must have {} dimensions, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:202:16\n",
" %14 : int = prim::Constant[value=3]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:199:63\n",
" %15 : str = prim::Constant[value=\"Expected hidden[0] size {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:622:31\n",
" %16 : str = prim::Constant[value=\"Expected hidden[1] size {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:624:31\n",
" %17 : bool = prim::Constant[value=1]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:66\n",
" %18 : str = prim::Constant[value=\"Expected more than 1 value per channel when training, got input size {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" %19 : float = prim::Constant[value=1.0000000000000001e-05]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/normalization.py:253:60\n",
" %20 : float = prim::Constant[value=0.5]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/dropout.py:58:32\n",
" %21 : int = prim::Constant[value=2]() # <ipython-input-1-14e2ee0fb7fa>:69:53\n",
" %22 : NoneType = prim::Constant()\n",
" %23 : int = prim::Constant[value=-1]() # <ipython-input-1-14e2ee0fb7fa>:72:29\n",
" %24 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name=\"encoder\"](%self)\n",
" %25 : Tensor = aten::slice(%x.1, %10, %22, %22, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n",
" %314 : Tensor = prim::profile[profiled_type=Float(2, 1000, strides=[1000, 1], requires_grad=0, device=cpu)](%25)\n",
" %26 : Tensor = aten::unsqueeze(%314, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n",
" %315 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%26)\n",
" %27 : Tensor = aten::slice(%315, %21, %22, %22, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n",
" %28 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"0\"](%24)\n",
" %29 : __torch__.torch.nn.modules.conv.Conv1d = prim::GetAttr[name=\"1\"](%24)\n",
" %30 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"2\"](%24)\n",
" %31 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"3\"](%24)\n",
" %32 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"5\"](%24)\n",
" %33 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"6\"](%24)\n",
" %34 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"7\"](%24)\n",
" %35 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"9\"](%24)\n",
" %36 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"10\"](%24)\n",
" %37 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"11\"](%24)\n",
" %38 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"13\"](%24)\n",
" %39 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"14\"](%24)\n",
" %40 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"15\"](%24)\n",
" %41 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"17\"](%24)\n",
" %42 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"18\"](%24)\n",
" %43 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"19\"](%24)\n",
" %44 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv1d = prim::GetAttr[name=\"21\"](%24)\n",
" %45 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"22\"](%24)\n",
" %46 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"23\"](%24)\n",
" %47 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv1d = prim::GetAttr[name=\"25\"](%24)\n",
" %48 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"26\"](%24)\n",
" %49 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"27\"](%24)\n",
" %50 : Tensor = prim::GetAttr[name=\"weight\"](%28)\n",
" %51 : Tensor = prim::GetAttr[name=\"bias\"](%28)\n",
" %316 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n",
" %52 : int = aten::size(%316, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %317 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n",
" %53 : int = aten::size(%317, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %54 : int = aten::mul(%52, %53) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %55 : int = aten::floordiv(%54, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %56 : int[] = prim::ListConstruct(%55, %11)\n",
" %318 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n",
" %57 : int[] = aten::size(%318) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %58 : int[] = aten::slice(%57, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %59 : int[] = aten::list(%58) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %60 : int[] = aten::add(%56, %59) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.2 : int = aten::__getitem__(%60, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %62 : int = aten::len(%60) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %63 : int = aten::sub(%62, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.4 : int = prim::Loop(%63, %17, %size_prods.2) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.2 : int, %size_prods.12 : int):\n",
" %67 : int = aten::add(%i.2, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %68 : int = aten::__getitem__(%60, %67) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.14 : int = aten::mul(%size_prods.12, %68) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.14)\n",
" %70 : bool = aten::eq(%size_prods.4, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%70) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %71 : str = aten::format(%18, %60) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%71) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %input.5 : Tensor = aten::group_norm(%27, %11, %50, %51, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %73 : Tensor = prim::GetAttr[name=\"weight\"](%29)\n",
" %74 : Tensor? = prim::GetAttr[name=\"bias\"](%29)\n",
" %input.9 : Tensor = aten::conv1d(%input.5, %73, %74, %8, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %76 : bool = prim::GetAttr[name=\"training\"](%30)\n",
" %319 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.9)\n",
" %input.13 : Tensor = aten::dropout(%319, %20, %76) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %78 : Tensor = prim::GetAttr[name=\"weight\"](%31)\n",
" %79 : Tensor = prim::GetAttr[name=\"bias\"](%31)\n",
" %320 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n",
" %80 : int = aten::size(%320, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %321 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n",
" %81 : int = aten::size(%321, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %82 : int = aten::mul(%80, %81) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %83 : int = aten::floordiv(%82, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %84 : int[] = prim::ListConstruct(%83, %11)\n",
" %322 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n",
" %85 : int[] = aten::size(%322) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %86 : int[] = aten::slice(%85, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %87 : int[] = aten::list(%86) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %88 : int[] = aten::add(%84, %87) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.16 : int = aten::__getitem__(%88, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %90 : int = aten::len(%88) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %91 : int = aten::sub(%90, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.18 : int = prim::Loop(%91, %17, %size_prods.16) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.4 : int, %size_prods.20 : int):\n",
" %95 : int = aten::add(%i.4, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %96 : int = aten::__getitem__(%88, %95) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.22 : int = aten::mul(%size_prods.20, %96) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.22)\n",
" %98 : bool = aten::eq(%size_prods.18, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%98) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %99 : str = aten::format(%18, %88) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%99) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %323 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n",
" %input.17 : Tensor = aten::group_norm(%323, %11, %78, %79, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %324 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.17)\n",
" %input.21 : Tensor = aten::gelu(%324) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %102 : Tensor = prim::GetAttr[name=\"weight\"](%32)\n",
" %103 : Tensor? = prim::GetAttr[name=\"bias\"](%32)\n",
" %325 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.21)\n",
" %input.25 : Tensor = aten::conv1d(%325, %102, %103, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %105 : bool = prim::GetAttr[name=\"training\"](%33)\n",
" %326 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.25)\n",
" %input.29 : Tensor = aten::dropout(%326, %20, %105) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %107 : Tensor = prim::GetAttr[name=\"weight\"](%34)\n",
" %108 : Tensor = prim::GetAttr[name=\"bias\"](%34)\n",
" %327 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n",
" %109 : int = aten::size(%327, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %328 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n",
" %110 : int = aten::size(%328, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %111 : int = aten::mul(%109, %110) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %112 : int = aten::floordiv(%111, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %113 : int[] = prim::ListConstruct(%112, %11)\n",
" %329 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n",
" %114 : int[] = aten::size(%329) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %115 : int[] = aten::slice(%114, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %116 : int[] = aten::list(%115) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %117 : int[] = aten::add(%113, %116) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.24 : int = aten::__getitem__(%117, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %119 : int = aten::len(%117) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %120 : int = aten::sub(%119, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.26 : int = prim::Loop(%120, %17, %size_prods.24) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.6 : int, %size_prods.28 : int):\n",
" %124 : int = aten::add(%i.6, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %125 : int = aten::__getitem__(%117, %124) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.30 : int = aten::mul(%size_prods.28, %125) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.30)\n",
" %127 : bool = aten::eq(%size_prods.26, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%127) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %128 : str = aten::format(%18, %117) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%128) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %330 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n",
" %input.33 : Tensor = aten::group_norm(%330, %11, %107, %108, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %331 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.33)\n",
" %input.37 : Tensor = aten::gelu(%331) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %131 : Tensor = prim::GetAttr[name=\"weight\"](%35)\n",
" %132 : Tensor? = prim::GetAttr[name=\"bias\"](%35)\n",
" %332 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.37)\n",
" %input.41 : Tensor = aten::conv1d(%332, %131, %132, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %134 : bool = prim::GetAttr[name=\"training\"](%36)\n",
" %333 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.41)\n",
" %input.45 : Tensor = aten::dropout(%333, %20, %134) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %136 : Tensor = prim::GetAttr[name=\"weight\"](%37)\n",
" %137 : Tensor = prim::GetAttr[name=\"bias\"](%37)\n",
" %334 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n",
" %138 : int = aten::size(%334, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %335 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n",
" %139 : int = aten::size(%335, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %140 : int = aten::mul(%138, %139) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %141 : int = aten::floordiv(%140, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %142 : int[] = prim::ListConstruct(%141, %11)\n",
" %336 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n",
" %143 : int[] = aten::size(%336) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %144 : int[] = aten::slice(%143, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %145 : int[] = aten::list(%144) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %146 : int[] = aten::add(%142, %145) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.32 : int = aten::__getitem__(%146, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %148 : int = aten::len(%146) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %149 : int = aten::sub(%148, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.34 : int = prim::Loop(%149, %17, %size_prods.32) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.8 : int, %size_prods.36 : int):\n",
" %153 : int = aten::add(%i.8, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %154 : int = aten::__getitem__(%146, %153) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.38 : int = aten::mul(%size_prods.36, %154) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.38)\n",
" %156 : bool = aten::eq(%size_prods.34, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%156) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %157 : str = aten::format(%18, %146) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%157) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %337 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n",
" %input.49 : Tensor = aten::group_norm(%337, %11, %136, %137, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %338 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.49)\n",
" %input.53 : Tensor = aten::gelu(%338) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %160 : Tensor = prim::GetAttr[name=\"weight\"](%38)\n",
" %161 : Tensor? = prim::GetAttr[name=\"bias\"](%38)\n",
" %339 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.53)\n",
" %input.57 : Tensor = aten::conv1d(%339, %160, %161, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %163 : bool = prim::GetAttr[name=\"training\"](%39)\n",
" %340 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.57)\n",
" %input.61 : Tensor = aten::dropout(%340, %20, %163) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %165 : Tensor = prim::GetAttr[name=\"weight\"](%40)\n",
" %166 : Tensor = prim::GetAttr[name=\"bias\"](%40)\n",
" %341 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n",
" %167 : int = aten::size(%341, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %342 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n",
" %168 : int = aten::size(%342, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %169 : int = aten::mul(%167, %168) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %170 : int = aten::floordiv(%169, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %171 : int[] = prim::ListConstruct(%170, %11)\n",
" %343 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n",
" %172 : int[] = aten::size(%343) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %173 : int[] = aten::slice(%172, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %174 : int[] = aten::list(%173) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %175 : int[] = aten::add(%171, %174) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.40 : int = aten::__getitem__(%175, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %177 : int = aten::len(%175) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %178 : int = aten::sub(%177, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.42 : int = prim::Loop(%178, %17, %size_prods.40) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.10 : int, %size_prods.44 : int):\n",
" %182 : int = aten::add(%i.10, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %183 : int = aten::__getitem__(%175, %182) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.46 : int = aten::mul(%size_prods.44, %183) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.46)\n",
" %185 : bool = aten::eq(%size_prods.42, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%185) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %186 : str = aten::format(%18, %175) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%186) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %344 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n",
" %input.65 : Tensor = aten::group_norm(%344, %11, %165, %166, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %345 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.65)\n",
" %input.69 : Tensor = aten::gelu(%345) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %189 : Tensor = prim::GetAttr[name=\"weight\"](%41)\n",
" %190 : Tensor? = prim::GetAttr[name=\"bias\"](%41)\n",
" %346 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.69)\n",
" %input.73 : Tensor = aten::conv1d(%346, %189, %190, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %192 : bool = prim::GetAttr[name=\"training\"](%42)\n",
" %347 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.73)\n",
" %input.77 : Tensor = aten::dropout(%347, %20, %192) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %194 : Tensor = prim::GetAttr[name=\"weight\"](%43)\n",
" %195 : Tensor = prim::GetAttr[name=\"bias\"](%43)\n",
" %348 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n",
" %196 : int = aten::size(%348, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %349 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n",
" %197 : int = aten::size(%349, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %198 : int = aten::mul(%196, %197) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %199 : int = aten::floordiv(%198, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %200 : int[] = prim::ListConstruct(%199, %11)\n",
" %350 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n",
" %201 : int[] = aten::size(%350) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %202 : int[] = aten::slice(%201, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %203 : int[] = aten::list(%202) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %204 : int[] = aten::add(%200, %203) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.48 : int = aten::__getitem__(%204, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %206 : int = aten::len(%204) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %207 : int = aten::sub(%206, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.50 : int = prim::Loop(%207, %17, %size_prods.48) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.12 : int, %size_prods.52 : int):\n",
" %211 : int = aten::add(%i.12, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %212 : int = aten::__getitem__(%204, %211) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.54 : int = aten::mul(%size_prods.52, %212) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.54)\n",
" %214 : bool = aten::eq(%size_prods.50, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%214) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %215 : str = aten::format(%18, %204) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%215) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %351 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n",
" %input.81 : Tensor = aten::group_norm(%351, %11, %194, %195, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %352 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.81)\n",
" %input.85 : Tensor = aten::gelu(%352) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %218 : Tensor = prim::GetAttr[name=\"weight\"](%44)\n",
" %219 : Tensor? = prim::GetAttr[name=\"bias\"](%44)\n",
" %353 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.85)\n",
" %input.89 : Tensor = aten::conv1d(%353, %218, %219, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %221 : bool = prim::GetAttr[name=\"training\"](%45)\n",
" %354 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.89)\n",
" %input.93 : Tensor = aten::dropout(%354, %20, %221) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %223 : Tensor = prim::GetAttr[name=\"weight\"](%46)\n",
" %224 : Tensor = prim::GetAttr[name=\"bias\"](%46)\n",
" %355 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n",
" %225 : int = aten::size(%355, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %356 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n",
" %226 : int = aten::size(%356, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %227 : int = aten::mul(%225, %226) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %228 : int = aten::floordiv(%227, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %229 : int[] = prim::ListConstruct(%228, %11)\n",
" %357 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n",
" %230 : int[] = aten::size(%357) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %231 : int[] = aten::slice(%230, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %232 : int[] = aten::list(%231) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %233 : int[] = aten::add(%229, %232) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.56 : int = aten::__getitem__(%233, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %235 : int = aten::len(%233) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %236 : int = aten::sub(%235, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods.58 : int = prim::Loop(%236, %17, %size_prods.56) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.14 : int, %size_prods.60 : int):\n",
" %240 : int = aten::add(%i.14, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %241 : int = aten::__getitem__(%233, %240) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.62 : int = aten::mul(%size_prods.60, %241) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.62)\n",
" %243 : bool = aten::eq(%size_prods.58, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%243) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %244 : str = aten::format(%18, %233) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%244) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %358 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n",
" %input.97 : Tensor = aten::group_norm(%358, %11, %223, %224, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %359 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.97)\n",
" %input.101 : Tensor = aten::gelu(%359) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %247 : Tensor = prim::GetAttr[name=\"weight\"](%47)\n",
" %248 : Tensor? = prim::GetAttr[name=\"bias\"](%47)\n",
" %360 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.101)\n",
" %input.105 : Tensor = aten::conv1d(%360, %247, %248, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n",
" %250 : bool = prim::GetAttr[name=\"training\"](%48)\n",
" %361 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.105)\n",
" %input.109 : Tensor = aten::dropout(%361, %20, %250) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n",
" %252 : Tensor = prim::GetAttr[name=\"weight\"](%49)\n",
" %253 : Tensor = prim::GetAttr[name=\"bias\"](%49)\n",
" %362 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n",
" %254 : int = aten::size(%362, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %363 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n",
" %255 : int = aten::size(%363, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n",
" %256 : int = aten::mul(%254, %255) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %257 : int = aten::floordiv(%256, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n",
" %258 : int[] = prim::ListConstruct(%257, %11)\n",
" %364 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n",
" %259 : int[] = aten::size(%364) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %260 : int[] = aten::slice(%259, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n",
" %261 : int[] = aten::list(%260) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n",
" %262 : int[] = aten::add(%258, %261) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n",
" %size_prods.1 : int = aten::__getitem__(%262, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n",
" %264 : int = aten::len(%262) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %265 : int = aten::sub(%264, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n",
" %size_prods : int = prim::Loop(%265, %17, %size_prods.1) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n",
" block0(%i.1 : int, %size_prods.11 : int):\n",
" %269 : int = aten::add(%i.1, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n",
" %270 : int = aten::__getitem__(%262, %269) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n",
" %size_prods.5 : int = aten::mul(%size_prods.11, %270) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n",
" -> (%17, %size_prods.5)\n",
" %272 : bool = aten::eq(%size_prods, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n",
" = prim::If(%272) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n",
" block0():\n",
" %273 : str = aten::format(%18, %262) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n",
" = prim::RaiseException(%273) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %365 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n",
" %input.113 : Tensor = aten::group_norm(%365, %11, %252, %253, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n",
" %366 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.113)\n",
" %input.117 : Tensor = aten::gelu(%366) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n",
" %367 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.117)\n",
" %h.1 : Tensor = aten::transpose(%367, %11, %21) # <ipython-input-1-14e2ee0fb7fa>:69:12\n",
" %277 : __torch__.torch.nn.modules.rnn.LSTM = prim::GetAttr[name=\"lstm\"](%self)\n",
" %368 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n",
" %max_batch_size.1 : int = aten::size(%368, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:658:29\n",
" %369 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n",
" %279 : int = prim::dtype(%369)\n",
" %370 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n",
" %280 : Device = prim::device(%370)\n",
" %281 : int[] = prim::ListConstruct(%11, %max_batch_size.1, %11)\n",
" %h_zeros.1 : Tensor = aten::zeros(%281, %279, %22, %280, %22) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:665:22\n",
" %c_zeros.1 : Tensor = aten::zeros(%281, %279, %22, %280, %22) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:668:22\n",
" %371 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n",
" %284 : int = aten::dim(%371) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:11\n",
" %285 : bool = aten::ne(%284, %14) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:11\n",
" = prim::If(%285) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:8\n",
" block0():\n",
" %286 : str = aten::format(%13, %14, %284) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:202:16\n",
" = prim::RaiseException(%286) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:201:12\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %372 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n",
" %287 : int = aten::size(%372, %23) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:30\n",
" %288 : bool = aten::ne(%11, %287) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:11\n",
" = prim::If(%288) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:8\n",
" block0():\n",
" %289 : str = aten::format(%12, %11, %287) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:206:16\n",
" = prim::RaiseException(%289) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:205:12\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %expected_hidden_size.3 : (int, int, int) = prim::TupleConstruct(%11, %max_batch_size.1, %11)\n",
" %373 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%h_zeros.1)\n",
" %291 : int[] = aten::size(%373) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n",
" %292 : int[] = prim::ListConstruct(%11, %max_batch_size.1, %11)\n",
" %293 : bool = aten::ne(%291, %292) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n",
" = prim::If(%293) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:8\n",
" block0():\n",
" %294 : int[] = aten::list(%291) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:64\n",
" %295 : str = aten::format(%15, %expected_hidden_size.3, %294) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:31\n",
" = prim::RaiseException(%295) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:12\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %374 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%c_zeros.1)\n",
" %296 : int[] = aten::size(%374) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n",
" %297 : bool = aten::ne(%296, %292) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n",
" = prim::If(%297) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:8\n",
" block0():\n",
" %298 : int[] = aten::list(%296) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:64\n",
" %299 : str = aten::format(%16, %expected_hidden_size.3, %298) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:31\n",
" = prim::RaiseException(%299) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:12\n",
" -> ()\n",
" block1():\n",
" -> ()\n",
" %300 : Tensor[] = prim::GetAttr[name=\"_flat_weights\"](%277)\n",
" %301 : bool = prim::GetAttr[name=\"training\"](%277)\n",
" %375 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%h_zeros.1)\n",
" %376 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%c_zeros.1)\n",
" %302 : Tensor[] = prim::ListConstruct(%375, %376)\n",
" %377 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n",
" %303 : Tensor, %304 : Tensor, %305 : Tensor = aten::lstm(%377, %302, %300, %17, %11, %20, %301, %9, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:679:21\n",
" %306 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name=\"fc\"](%self)\n",
" %307 : Tensor = prim::GetAttr[name=\"weight\"](%306)\n",
" %308 : Tensor = prim::GetAttr[name=\"bias\"](%306)\n",
" %378 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[1, 2, 1], requires_grad=1, device=cpu)](%303)\n",
" %379 : Tensor = prim::profile[profiled_type=Float(30, 1, strides=[1, 1], requires_grad=1, device=cpu)](%307)\n",
" %380 : Tensor = prim::profile[profiled_type=Float(30, strides=[1], requires_grad=1, device=cpu)](%308)\n",
" %h.9 : Tensor = aten::linear(%378, %379, %380) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1847:11\n",
" %381 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%h.9)\n",
" %z.1 : Tensor = aten::log_softmax(%381, %23, %22) # <ipython-input-1-14e2ee0fb7fa>:72:15\n",
" %382 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%z.1)\n",
" %h.2 : Tensor = aten::transpose(%382, %10, %11) # <ipython-input-1-14e2ee0fb7fa>:85:12\n",
" %383 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[30, 60, 1], requires_grad=1, device=cpu)](%h.2)\n",
" %312 : Tensor = aten::ctc_loss(%383, %y.1, %xlen.1, %ylen.1, %10, %11, %9) # <ipython-input-1-14e2ee0fb7fa>:86:15\n",
" %384 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%z.1)\n",
" %313 : (Tensor, Tensor) = prim::TupleConstruct(%312, %384)\n",
" = prim::profile()\n",
" return (%313)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# TorchScript-compatible (jittable) CTC model definition\n",
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class SpeechModel(nn.Module):\n",
" def __init__(self, n_vocab, n_hid, n_feat=1, dropout=0.5, act=nn.GELU):\n",
" super().__init__()\n",
" self.encoder = nn.Sequential(\n",
" # Wav2vec 2.0 style encoder https://arxiv.org/abs/2006.11477\n",
" nn.GroupNorm(1, n_feat),\n",
" \n",
" nn.Conv1d(n_feat, n_hid, 10, 5),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
"\n",
" nn.Conv1d(n_hid, n_hid, 3, 2),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
" \n",
" nn.Conv1d(n_hid, n_hid, 3, 2),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
" \n",
" nn.Conv1d(n_hid, n_hid, 3, 2),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
" \n",
" nn.Conv1d(n_hid, n_hid, 3, 2),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
" \n",
" nn.Conv1d(n_hid, n_hid, 2, 2),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
"\n",
" nn.Conv1d(n_hid, n_hid, 2, 2),\n",
" nn.Dropout(dropout),\n",
" nn.GroupNorm(1, n_hid),\n",
" act(),\n",
" )\n",
" self.lstm = nn.LSTM(n_hid, n_hid, 1, dropout=dropout, batch_first=True) # dropout has no effect with num_layers=1 (PyTorch warns)\n",
" self.fc = nn.Linear(n_hid, n_vocab)\n",
"\n",
" @torch.jit.ignore # not scriptable: iterates nn.Sequential submodules with isinstance checks\n",
" def convlen(self, xlen):\n",
" xlen = xlen.float()\n",
" for m in self.encoder:\n",
" if isinstance(m, nn.Conv1d):\n",
" xlen = torch.floor((xlen - m.kernel_size[0]) / m.stride[0] + 1)\n",
" return xlen.long()\n",
" \n",
" @torch.jit.export\n",
" def inference(self, x):\n",
" \"\"\"Transcribe one wav sequence to ids.\"\"\"\n",
" ids = self.logprob(x[None])[0].argmax(-1)\n",
" ids = torch.unique_consecutive(ids) # collapse repeated frames\n",
" return ids[ids != 0] # drop CTC blanks (id 0)\n",
"\n",
" @torch.jit.export\n",
" def logprob(self, x):\n",
" \"\"\"Predicts the log-softmax distribution (batch, time, n_vocab).\"\"\"\n",
" h = self.encoder(x[:, None, :]).transpose(1, 2)\n",
" h, _ = self.lstm(h)\n",
" h = self.fc(h)\n",
" return h.log_softmax(-1)\n",
" \n",
" @torch.jit.export\n",
" def forward(self, x, xlen, y, ylen):\n",
" \"\"\"Computes CTC loss.\n",
" \n",
" Args:\n",
" x: input waveform, float (batch, time1)\n",
" xlen: encoded lengths, long (batch)\n",
" y: target ids, long (batch, time2)\n",
" ylen: target lengths, long (batch)\n",
"\n",
" Returns:\n",
" loss: mean CTC loss over the batch\n",
" z: log probabilities (batch, time1, n_vocab)\n",
" \"\"\"\n",
" z = self.logprob(x) # (batch, time1, n_vocab)\n",
" h = z.transpose(0, 1) # (time1, batch, n_vocab)\n",
" return torch.ctc_loss(h, y, xlen, ylen), z\n",
"\n",
"\n",
"model = torch.jit.script(SpeechModel(30, 1))\n",
"x = torch.randn(2, 1000)\n",
"xlen = model.convlen(torch.tensor([1000, 900]).long())\n",
"y = torch.ones(2, 10).long()\n",
"ylen = torch.tensor([10, 9]).long()\n",
"print(\"#params:\\t\", sum(v.numel() for v in model.state_dict().values()))\n",
"print(\"out shape:\\t\", model.inference(x[0]).shape)\n",
"model.graph_for(x, xlen, y, ylen)"
]
},
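{
"cell_type": "markdown",
"id": "7f3ab2c1",
"metadata": {},
"source": [
"The `convlen` helper above applies the standard 1-D convolution output-length formula, `L_out = floor((L_in - kernel_size) / stride + 1)`, once per `Conv1d` layer. A minimal standalone sketch of the same computation (the function name `conv1d_out_len` is illustrative, not from this notebook):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def conv1d_out_len(length, kernel_size, stride):\n",
"    # no padding/dilation, matching the Conv1d layers in SpeechModel\n",
"    return math.floor((length - kernel_size) / stride + 1)\n",
"\n",
"# chain the SpeechModel encoder layers: (10, 5), then (3, 2) x4, then (2, 2) x2\n",
"length = 1000\n",
"for k, s in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:\n",
"    length = conv1d_out_len(length, k, s)\n",
"print(length)  # 2 frames for 1000 samples, matching the profiled Float(2, 1, 2) shapes above\n",
"```\n"
]
},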
{
"cell_type": "code",
"execution_count": 2,
"id": "38f14b44",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"848it [00:00, 528494.77it/s]\n",
"848it [00:00, 4886.44it/s]\n",
"100it [00:00, 412014.15it/s]\n",
"100it [00:00, 4852.50it/s]\n",
"130it [00:00, 388638.29it/s]\n",
"130it [00:00, 4795.81it/s]\n"
]
}
],
"source": [
"# Dataset\n",
"import os\n",
"import librosa\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class ESPnetDataset(torch.utils.data.Dataset):\n",
" def __init__(self, root, sr=16000):\n",
" self.root = root\n",
" self.sr = sr\n",
" self.texts = {}\n",
" self.wavs = {}\n",
" self.ids = []\n",
" self.tokenizer = None\n",
" self.vocab = set()\n",
" with open(os.path.join(root, \"text\"), \"r\") as f:\n",
" for line in tqdm(f):\n",
" i, t = line.split(maxsplit=1)\n",
" self.ids.append(i)\n",
" t = t.strip()\n",
" self.texts[i] = t\n",
" for c in t:\n",
" self.vocab.add(c)\n",
" with open(os.path.join(root, \"wav.scp\"), \"r\") as f:\n",
" for line in tqdm(f):\n",
" s = line.split()\n",
" path = os.path.join(root, \"../..\", s[-2])\n",
" self.wavs[s[0]] = torch.as_tensor(librosa.load(path, sr=sr)[0])\n",
"\n",
" def __len__(self):\n",
" return len(self.ids)\n",
" \n",
" def __getitem__(self, idx):\n",
" k = self.ids[idx]\n",
" wav = self.wavs[k]\n",
" return wav, self.tokenizer.encode(self.texts[k])\n",
" \n",
"\n",
"class Tokenizer:\n",
" def __init__(self, *datasets):\n",
" vocabs = []\n",
" for d in datasets:\n",
" vocabs.append(d.vocab)\n",
" d.tokenizer = self\n",
" tokens = sorted(list(set.union(*vocabs)))\n",
" self.s2i = {'[blank]': 0, '[unk]': 1}\n",
" self.i2s = {0: '[blank]', 1: '[unk]'}\n",
" for i, token in enumerate(tokens, len(self.s2i)):\n",
" self.i2s[i] = token\n",
" self.s2i[token] = i\n",
" \n",
" def __len__(self):\n",
" return len(self.s2i)\n",
"\n",
" def encode(self, s):\n",
" return torch.tensor([self.s2i[c] for c in s])\n",
"\n",
" def decode(self, e):\n",
" return \"\".join(self.i2s[i.item()] for i in e)\n",
"\n",
" \n",
"trainset = ESPnetDataset('../egs/an4/asr1/data/train_nodev/')\n",
"devset = ESPnetDataset('../egs/an4/asr1/data/train_dev/')\n",
"testset = ESPnetDataset('../egs/an4/asr1/data/test')\n",
"tokenizer = Tokenizer(trainset, devset, testset)"
]
},
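{
"cell_type": "markdown",
"id": "9d4e6f20",
"metadata": {},
"source": [
"The training cell below reports a character error rate (CER). Using the `editdistance` package listed in the requirements, CER is the Levenshtein distance between hypothesis and reference divided by the reference length; a minimal sketch (the function name `cer` is illustrative, not from this notebook):\n",
"\n",
"```python\n",
"import editdistance\n",
"\n",
"def cer(hyp: str, ref: str) -> float:\n",
"    # edit distance normalized by reference length\n",
"    return editdistance.eval(hyp, ref) / len(ref)\n",
"\n",
"cer(\"HELLO\", \"HELO\")  # 1 edit / 4 reference chars = 0.25\n",
"```\n"
]
},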
{
"cell_type": "code",
"execution_count": 4,
"id": "eb1be0e8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/cuda/__init__.py:83: UserWarning: \n",
" Found GPU%d %s which is of cuda capability %d.%d.\n",
" PyTorch no longer supports this GPU because it is too old.\n",
" The minimum cuda capability supported by this library is %d.%d.\n",
" \n",
" warnings.warn(old_gpu_warn.format(d, name, major, minor, min_arch // 10, min_arch % 10))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"max xlen: 102400\n",
"max ylen: 57\n",
"#vocab: 29\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at /opt/conda/conda-bld/pytorch_1623448278899/work/aten/src/ATen/native/cudnn/RNN.cpp:924.)\n",
" return forward_call(*input, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0\n",
"sec/epoch: 3.954852342605591\n",
"train loss: tensor(5.2251)\n",
"train cer: tensor(0.9882)\n",
"dev loss: tensor(3.0169)\n",
"dev cer: tensor(1.)\n",
"epoch: 5\n",
"sec/epoch: 3.640742063522339\n",
"train loss: tensor(2.9098)\n",
"train cer: tensor(1.)\n",
"dev loss: tensor(2.9442)\n",
"dev cer: tensor(1.)\n",
"epoch: 10\n",
"sec/epoch: 3.664271354675293\n",
"train loss: tensor(2.7860)\n",
"train cer: tensor(0.9860)\n",
"dev loss: tensor(2.7893)\n",
"dev cer: tensor(0.9835)\n",
"epoch: 15\n",
"sec/epoch: 3.6453144550323486\n",
"train loss: tensor(2.6531)\n",
"train cer: tensor(0.9509)\n",
"dev loss: tensor(2.6637)\n",
"dev cer: tensor(0.9455)\n",
"epoch: 20\n",
"sec/epoch: 3.6640310287475586\n",
"train loss: tensor(2.4639)\n",
"train cer: tensor(0.9403)\n",
"dev loss: tensor(2.5285)\n",
"dev cer: tensor(0.9422)\n",
"epoch: 25\n",
"sec/epoch: 3.6508846282958984\n",
"train loss: tensor(2.2603)\n",
"train cer: tensor(0.7138)\n",
"dev loss: tensor(2.3292)\n",
"dev cer: tensor(0.7413)\n",
"epoch: 30\n",
"sec/epoch: 3.669067859649658\n",
"train loss: tensor(2.0850)\n",
"train cer: tensor(0.6624)\n",
"dev loss: tensor(2.1549)\n",
"dev cer: tensor(0.6562)\n",
"epoch: 35\n",
"sec/epoch: 3.6770923137664795\n",
"train loss: tensor(1.8892)\n",
"train cer: tensor(0.5947)\n",
"dev loss: tensor(2.0249)\n",
"dev cer: tensor(0.6347)\n",
"epoch: 40\n",
"sec/epoch: 3.6583163738250732\n",
"train loss: tensor(1.6512)\n",
"train cer: tensor(0.4884)\n",
"dev loss: tensor(1.8241)\n",
"dev cer: tensor(0.5076)\n",
"epoch: 45\n",
"sec/epoch: 3.678729295730591\n",
"train loss: tensor(1.4269)\n",
"train cer: tensor(0.3909)\n",
"dev loss: tensor(1.7563)\n",
"dev cer: tensor(0.4751)\n",
"epoch: 50\n",
"sec/epoch: 3.658202886581421\n",
"train loss: tensor(1.2555)\n",
"train cer: tensor(0.3106)\n",
"dev loss: tensor(1.5976)\n",
"dev cer: tensor(0.3817)\n",
"epoch: 55\n",
"sec/epoch: 3.6857497692108154\n",
"train loss: tensor(1.0824)\n",
"train cer: tensor(0.2570)\n",
"dev loss: tensor(1.4412)\n",
"dev cer: tensor(0.3340)\n",
"epoch: 60\n",
"sec/epoch: 3.6613011360168457\n",
"train loss: tensor(0.9420)\n",
"train cer: tensor(0.2111)\n",
"dev loss: tensor(1.4324)\n",
"dev cer: tensor(0.3054)\n",
"epoch: 65\n",
"sec/epoch: 3.6761670112609863\n",
"train loss: tensor(0.8255)\n",
"train cer: tensor(0.1789)\n",
"dev loss: tensor(1.3962)\n",
"dev cer: tensor(0.2882)\n",
"epoch: 70\n",
"sec/epoch: 3.682213306427002\n",
"train loss: tensor(0.7396)\n",
"train cer: tensor(0.1536)\n",
"dev loss: tensor(1.3824)\n",
"dev cer: tensor(0.2519)\n",
"epoch: 75\n",
"sec/epoch: 3.653740882873535\n",
"train loss: tensor(0.6194)\n",
"train cer: tensor(0.1323)\n",
"dev loss: tensor(1.5055)\n",
"dev cer: tensor(0.2020)\n",
"epoch: 80\n",
"sec/epoch: 3.6865620613098145\n",
"train loss: tensor(0.5580)\n",
"train cer: tensor(0.1158)\n",
"dev loss: tensor(1.4025)\n",
"dev cer: tensor(0.2222)\n",
"epoch: 85\n",
"sec/epoch: 3.6628165245056152\n",
"train loss: tensor(0.4663)\n",
"train cer: tensor(0.0924)\n",
"dev loss: tensor(1.4166)\n",
"dev cer: tensor(0.1741)\n",
"epoch: 90\n",
"sec/epoch: 3.6864466667175293\n",
"train loss: tensor(0.4085)\n",
"train cer: tensor(0.0800)\n",
"dev loss: tensor(1.4796)\n",
"dev cer: tensor(0.1821)\n",
"epoch: 95\n",
"sec/epoch: 3.664776563644409\n",
"train loss: tensor(0.3649)\n",
"train cer: tensor(0.0709)\n",
"dev loss: tensor(1.4671)\n",
"dev cer: tensor(0.2182)\n",
"epoch: 100\n",
"sec/epoch: 3.677140235900879\n",
"train loss: tensor(0.3022)\n",
"train cer: tensor(0.0581)\n",
"dev loss: tensor(1.4241)\n",
"dev cer: tensor(0.1634)\n"
]
}
],
"source": [
"# Training\n",
"import time\n",
"import itertools\n",
"import editdistance\n",
"\n",
"# hyperparams\n",
"n_batch = 16\n",
"n_epoch = 100\n",
"n_hid = 128\n",
"lr = 0.001\n",
"dropout = 0.0\n",
"clip = 1.0\n",
"device = torch.device('cuda')\n",
"max_xlen = max(map(len, itertools.chain(trainset.wavs.values(), devset.wavs.values(), testset.wavs.values())))\n",
"max_ylen = max(map(len, itertools.chain(trainset.texts.values(), devset.texts.values(), testset.texts.values())))\n",
"print(\"max xlen:\", max_xlen)\n",
"print(\"max ylen:\", max_ylen)\n",
"print(\"#vocab:\", len(tokenizer))\n",
"\n",
"def collate(batch):\n",
"    # zero-pad each waveform and transcript to the dataset-wide max length\n",
" xlen, ylen = [], []\n",
" xpad = torch.zeros(n_batch, max_xlen)\n",
" ypad = torch.zeros(n_batch, max_ylen)\n",
" for i, (x, y) in enumerate(batch):\n",
" xpad[i, :len(x)] = x\n",
" ypad[i, :len(y)] = y\n",
" xlen.append(len(x))\n",
" ylen.append(len(y))\n",
" return xpad, torch.tensor(xlen), ypad, torch.tensor(ylen)\n",
"\n",
"\n",
"def cer(pred, plen, ypad, ylen):\n",
"    \"\"\"Character error rate of greedy CTC output against the reference.\"\"\"\n",
"    ids = pred.argmax(-1)  # greedy frame-wise token ids\n",
"    err = 0\n",
"    n = 0\n",
"    for p, pl, y, yl in zip(ids, plen, ypad, ylen):\n",
"        p = torch.unique_consecutive(p[:pl])  # collapse repeated tokens\n",
"        p = p[p != 0]  # filter blank (id 0)\n",
" err += editdistance.eval(p, y[:yl])\n",
" n += yl\n",
" return err / n\n",
"\n",
"\n",
"# drop_last=True keeps every batch at exactly n_batch, matching collate's fixed-size padding\n",
"train_loader = torch.utils.data.DataLoader(trainset, n_batch, collate_fn=collate, shuffle=True, drop_last=True)\n",
"dev_loader = torch.utils.data.DataLoader(devset, n_batch, collate_fn=collate, shuffle=False, drop_last=True)\n",
"test_loader = torch.utils.data.DataLoader(testset, n_batch, collate_fn=collate, shuffle=False, drop_last=True)\n",
"\n",
"model = torch.jit.script(SpeechModel(len(tokenizer), n_hid=n_hid, dropout=dropout))\n",
"model.to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
"for epoch in range(n_epoch + 1):\n",
" model.train()\n",
" loss_hist = []\n",
" cer_hist = []\n",
" start = time.time()\n",
" for xpad, xlen, ypad, ylen in train_loader:\n",
" optimizer.zero_grad()\n",
" hlen = model.convlen(xlen)\n",
" loss, pred = model(xpad.to(device), hlen.to(device),\n",
" ypad.to(device), ylen.to(device))\n",
" loss.backward()\n",
" nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
" optimizer.step()\n",
" cer_hist.append(cer(pred, hlen, ypad, ylen))\n",
" loss_hist.append(loss.float())\n",
" if epoch % 5 == 0:\n",
" print(\"epoch:\", epoch)\n",
" print(\"sec/epoch:\", time.time() - start)\n",
" print(\"train loss:\", torch.mean(torch.tensor(loss_hist)).float())\n",
" print(\"train cer:\", torch.mean(torch.tensor(cer_hist)).float())\n",
"\n",
" loss_hist = []\n",
" cer_hist = []\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for xpad, xlen, ypad, ylen in dev_loader:\n",
" hlen = model.convlen(xlen)\n",
" loss, pred = model(xpad.to(device), hlen.to(device),\n",
" ypad.to(device), ylen.to(device))\n",
" cer_hist.append(cer(pred, hlen, ypad, ylen))\n",
" loss_hist.append(loss.float())\n",
" if epoch % 5 == 0:\n",
" print(\"dev loss:\", torch.mean(torch.tensor(loss_hist)).float())\n",
" print(\"dev cer:\", torch.mean(torch.tensor(cer_hist)).float())\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0525926a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pred: \n",
"truth: YES\n",
"pred: FIFB SFTY SIX\n",
"truth: FIFTY ONE FIFTY SIX\n",
"pred: U FIVE TWO ONE SEVEN\n",
"truth: ONE FIVE TWO ONE SEVEN\n",
"pred: ETER FIVE\n",
"truth: ENTER FIVE\n",
"pred: UBOU J IU TWE\n",
"truth: RUBOUT J U I P THREE TWO EIGHT\n",
"pred: J LR E\n",
"truth: J E N N I F E R\n"
]
}
],
"source": [
"# Greedy decoding on CPU with the scripted model\n",
"model.cpu()\n",
"model.eval()\n",
"for i, (xs, xlen, ys, ylen) in enumerate(dev_loader):\n",
" with torch.no_grad():\n",
" x = xs[0, :xlen[0]]\n",
" pred = model.inference(x)\n",
"\n",
" print(\"pred: \", tokenizer.decode(pred))\n",
" print(\"truth:\", tokenizer.decode(ys[0, :ylen[0]]))\n",
" if i == 5:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "742e96c7",
"metadata": {},
"outputs": [],
"source": [
"# Serialize the scripted model (code + weights) into a single file\n",
"torch.jit.save(model, 'ctc.pt')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e2f818e1",
"metadata": {},
"outputs": [],
"source": [
"# Reload the TorchScript model; no Python class definition is required\n",
"loaded = torch.jit.load('ctc.pt')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "360beaf7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pred: \n",
"truth: YES\n",
"pred: FIFB SFTY SIX\n",
"truth: FIFTY ONE FIFTY SIX\n",
"pred: U FIVE TWO ONE SEVEN\n",
"truth: ONE FIVE TWO ONE SEVEN\n",
"pred: ETER FIVE\n",
"truth: ENTER FIVE\n",
"pred: UBOU J IU TWE\n",
"truth: RUBOUT J U I P THREE TWO EIGHT\n",
"pred: J LR E\n",
"truth: J E N N I F E R\n"
]
}
],
"source": [
"# Check that the reloaded model reproduces the same predictions\n",
"for i, (xs, xlen, ys, ylen) in enumerate(dev_loader):\n",
" with torch.no_grad():\n",
" x = xs[0, :xlen[0]]\n",
" pred = loaded.inference(x)\n",
"\n",
" print(\"pred: \", tokenizer.decode(pred))\n",
" print(\"truth:\", tokenizer.decode(ys[0, :ylen[0]]))\n",
" if i == 5:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f2e5dcbd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thu Jun 24 02:31:24 2021 \r\n",
"+-----------------------------------------------------------------------------+\r\n",
"| NVIDIA-SMI 460.80 Driver Version: 460.80 CUDA Version: 11.2 |\r\n",
"|-------------------------------+----------------------+----------------------+\r\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n",
"| | | MIG M. |\r\n",
"|===============================+======================+======================|\r\n",
"| 0 GeForce GTX 760 Off | 00000000:01:00.0 N/A | N/A |\r\n",
"| 36% 48C P8 N/A / N/A | 121MiB / 1996MiB | N/A Default |\r\n",
"| | | N/A |\r\n",
"+-------------------------------+----------------------+----------------------+\r\n",
"| 1 GeForce GTX 108... Off | 00000000:02:00.0 Off | N/A |\r\n",
"| 0% 51C P8 10W / 250W | 2099MiB / 11178MiB | 0% Default |\r\n",
"| | | N/A |\r\n",
"+-------------------------------+----------------------+----------------------+\r\n",
" \r\n",
"+-----------------------------------------------------------------------------+\r\n",
"| Processes: |\r\n",
"| GPU GI CI PID Type Process name GPU Memory |\r\n",
"| ID ID Usage |\r\n",
"|=============================================================================|\r\n",
"| 1 N/A N/A 1456 G /usr/lib/xorg/Xorg 4MiB |\r\n",
"| 1 N/A N/A 25056 C ...net/tools/venv/bin/python 2091MiB |\r\n",
"+-----------------------------------------------------------------------------+\r\n"
]
}
],
"source": [
"!nvidia-smi"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}