Skip to content

Instantly share code, notes, and snippets.

@usamec
Created February 6, 2020 14:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save usamec/af21be7b83e6b1a3f38c26136af811f3 to your computer and use it in GitHub Desktop.
Save usamec/af21be7b83e6b1a3f38c26136af811f3 to your computer and use it in GitHub Desktop.
GRU tuning
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch import Tensor\n",
"from typing import List"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"inputs = [torch.randn(64, 256, 256), torch.randn(128, 128, 256), torch.randn(256, 64, 256), torch.randn(512, 32, 256)]\n",
"inputs = [x.cuda() for x in inputs]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"gru_cudnn = nn.GRU(256, 256).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.93 ms ± 36.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"out, _ = gru_cudnn(inputs[1])\n",
"s = out.sum()\n",
"s.sum().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.91 ms ± 796 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"out, _ = gru_cudnn(inputs[1])\n",
"s = out.sum()\n",
"s.backward()\n",
"gru_cudnn.bias_ih_l0.grad.sum().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class GRU(nn.Module):\n",
"    \"\"\"Single-layer GRU built from plain tensor ops, one time step per loop pass.\n",
"\n",
"    The same size `ks` is used for the input features and the hidden state.\n",
"    Both weight matrices have shape `(ks, 3 * ks)` and both biases have shape\n",
"    `(3 * ks,)`; the three column blocks feed the reset gate, the update gate\n",
"    and the candidate state, in that order. All parameters are initialised\n",
"    with `torch.randn`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, ks=256):\n",
"        \"\"\"Create the parameters for feature/hidden size `ks`.\"\"\"\n",
"        super().__init__()\n",
"\n",
"        self.ks = ks\n",
"\n",
"        # Creation order (weight_ih, weight_hh, bias_ih, bias_hh) is kept so\n",
"        # that a seeded RNG yields the same parameter values.\n",
"        self.weight_ih = nn.Parameter(torch.randn(ks, 3 * ks))\n",
"        self.weight_hh = nn.Parameter(torch.randn(ks, 3 * ks))\n",
"        self.bias_ih = nn.Parameter(torch.randn(3 * ks))\n",
"        self.bias_hh = nn.Parameter(torch.randn(3 * ks))\n",
"\n",
"    def forward(self, x: Tensor):\n",
"        \"\"\"Run the recurrence over dim 0 of `x` and return every hidden state.\n",
"\n",
"        `x` is unbound along dim 0 and each slice is passed to `torch.mm`\n",
"        together with `weight_ih`. The hidden state starts as zeros shaped\n",
"        like `x[0]`. The per-step hidden states are stacked along a new\n",
"        leading dimension and returned.\n",
"        \"\"\"\n",
"        steps = x.unbind(0)\n",
"\n",
"        # Typed list plus positional chunk() keep the module scriptable.\n",
"        states: List[Tensor] = []\n",
"        hidden = torch.zeros_like(x[0])\n",
"        for step in steps:\n",
"            in_proj = torch.mm(step, self.weight_ih) + self.bias_ih\n",
"            hid_proj = torch.mm(hidden, self.weight_hh) + self.bias_hh\n",
"            in_r, in_z, in_n = in_proj.chunk(3, 1)\n",
"            hid_r, hid_z, hid_n = hid_proj.chunk(3, 1)\n",
"\n",
"            reset = torch.sigmoid(in_r + hid_r)\n",
"            update = torch.sigmoid(in_z + hid_z)\n",
"\n",
"            # The reset gate scales only the hidden-side candidate term.\n",
"            candidate = torch.tanh(in_n + reset * hid_n)\n",
"            hidden = (1 - update) * candidate + update * hidden\n",
"\n",
"            states.append(hidden)\n",
"\n",
"        return torch.stack(states)\n",
" \n",
"gru_raw = GRU(256).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"68.9 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"out = gru_raw(inputs[1])\n",
"s = out.sum()\n",
"s.sum().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"161 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"out = gru_raw(inputs[1])\n",
"s = out.sum()\n",
"s.backward()\n",
"gru_raw.bias_ih.grad.sum().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"gru_jit = torch.jit.script(GRU(256).cuda())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"22.7 ms ± 23.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"out = gru_jit(inputs[1])\n",
"s = out.sum()\n",
"s.sum().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"58.1 ms ± 682 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"out = gru_jit(inputs[1])\n",
"s = out.sum()\n",
"s.backward()\n",
"gru_jit.bias_ih.grad.sum().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment