Skip to content

Instantly share code, notes, and snippets.

@yoyololicon
Last active August 25, 2022 12:59
Show Gist options
  • Save yoyololicon/f63f601d62187562070a61377cec9bf8 to your computer and use it in GitHub Desktop.
Save yoyololicon/f63f601d62187562070a61377cec9bf8 to your computer and use it in GitHub Desktop.
This lfilter can propagate gradients to the filter coefficients.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import lfilter as torch_lfilter
from torch.autograd import Function, gradcheck
class lfilter(Function):
    """Differentiable IIR filter ``y = B(z) / A(z) * x`` along the last dimension.

    The recursive (denominator) part is computed with torchaudio's ``lfilter``
    using a unit numerator ``[1, 0, ..., 0]``.  The feed-forward (numerator)
    part is applied afterwards as a causal 1-D convolution.  Because the two
    stages are separate, ``a`` and ``b`` may have different lengths.

    Gradients are provided for the input signal ``x``, the denominator
    coefficients ``a`` and the numerator coefficients ``b``.
    """

    @staticmethod
    def forward(ctx, x, a, b) -> torch.Tensor:
        """Filter ``x`` with numerator ``b`` and denominator ``a``.

        Args:
            x: input signal. Filtering runs along the last dimension, and every
                leading index is treated as an independent sequence.
            a: 1-D denominator coefficients ``[a0, a1, ...]``.
            b: 1-D numerator coefficients ``[b0, b1, ...]``.

        Returns:
            The filtered signal, with the same shape as ``x``.
        """
        with torch.no_grad():
            # A unit numerator makes torchaudio apply only 1 / A(z).
            unit = torch.zeros_like(a)
            unit[0] = 1
            xh = torch_lfilter(x, a, unit, False)
            # Causal FIR B(z): left-pad by len(b) - 1 and correlate with the
            # flipped kernel (conv1d computes a cross-correlation).
            y = F.pad(xh.reshape(-1, 1, xh.shape[-1]), [b.numel() - 1, 0])
            y = F.conv1d(y, b.flip(0).view(1, 1, -1)).view(*xh.shape)
        # x itself is not needed in backward; xh carries the same shape.
        ctx.save_for_backward(a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """Return the gradients with respect to ``(x, a, b)``.

        An entry is ``None`` when the corresponding input does not require a
        gradient.
        """
        a, b, xh = ctx.saved_tensors
        dx = da = db = None
        n_time = xh.shape[-1]
        batch = xh.numel() // n_time
        # Autograd may hand over a non-contiguous gradient (e.g. after a
        # transpose downstream), so use reshape instead of view.
        dy = dy.reshape(xh.shape)
        with torch.no_grad():
            if ctx.needs_input_grad[2]:
                # db[k] = sum_n xh[n - k] * dy[n]; one output per b coefficient.
                db = F.conv1d(F.pad(xh.reshape(1, -1, n_time), [b.numel() - 1, 0]),
                              dy.reshape(-1, 1, n_time),
                              groups=batch).sum((0, 1)).flip(0)
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                # Gradient w.r.t. xh is the adjoint of the causal FIR:
                # dxh[n] = sum_k b[k] * dy[n + k] (right padding).
                dxh = F.conv1d(F.pad(dy.reshape(-1, 1, n_time), [0, b.numel() - 1]),
                               b.view(1, 1, -1)).view(*xh.shape)
                unit = torch.zeros_like(a)
                if ctx.needs_input_grad[0]:
                    # The adjoint of 1 / A(z) is the same recursion run
                    # backwards in time.
                    unit[0] = 1
                    dx = torch_lfilter(dxh.flip(-1), a, unit, False).flip(-1)
                if ctx.needs_input_grad[1]:
                    # d xh / d a_k = -z^{-k} xh / A(z).
                    unit[0] = -1
                    dxhda = torch_lfilter(xh, a, unit, False)
                    # da[k] = sum_n dxhda[n - k] * dxh[n]. There is one output
                    # per coefficient of *a*, so pad by len(a) - 1, not
                    # len(b) - 1.
                    da = F.conv1d(F.pad(dxhda.reshape(1, -1, n_time), [a.numel() - 1, 0]),
                                  dxh.reshape(-1, 1, n_time),
                                  groups=batch).sum((0, 1)).flip(0)
        return dx, da, db
if __name__ == '__main__':
    # Self-test: profile one forward/backward pass, then verify the analytic
    # gradients against finite differences with gradcheck.
    # Run on the GPU when one is present; otherwise fall back to the CPU so the
    # script also works on machines without CUDA.
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    x = torch.randn(4, 256, device=device, dtype=torch.double)
    a = torch.rand(3, device=device, dtype=torch.double)
    b = torch.rand(3, device=device, dtype=torch.double)
    # Normalise so the leading denominator coefficient is exactly 1.
    a.div_(a[0].item())
    a.requires_grad = True
    b.requires_grad = True
    x.requires_grad = True
    print(a, b)

    with torch.autograd.profiler.profile(use_cuda=use_cuda, profile_memory=True) as prof:
        y = lfilter.apply(x, a, b)
        loss = y.abs().sum()
        loss.backward()
    sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=5))

    print(gradcheck(lfilter.apply, (x, a, b), eps=1e-6))
    # The numerator may be longer than the denominator; the gradient for `a`
    # must still have len(a) entries.
    b_long = torch.rand(5, device=device, dtype=torch.double, requires_grad=True)
    print(gradcheck(lfilter.apply, (x, a, b_long), eps=1e-6))
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ycy/miniconda3/envs/hrtf_notebooks/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.utils.benchmark as benchmark\n",
"from numpy.random import uniform\n",
"from differentiable_lfilter import lfilter\n",
"\n",
"from itertools import product"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define simple second-order IIR\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"#https://github.com/boris-kuz/differentiable_iir_filters/blob/master/differentiable_tdf2_model.py\n",
"class DTDFIICell(nn.Module):\n",
" def __init__(self):\n",
" super(DTDFIICell, self).__init__()\n",
" self.b0 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n",
" self.b1 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n",
" self.b2 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n",
" self.a0 = nn.Parameter(torch.FloatTensor([1]))\n",
" self.a1 = nn.Parameter(torch.FloatTensor([uniform(-0.5, 0.5)]))\n",
" self.a2 = nn.Parameter(torch.FloatTensor([uniform(-0.5, 0.5)]))\n",
"\n",
" def _cat(self, vectors):\n",
" return torch.cat([v_.unsqueeze(-1) for v_ in vectors], dim=-1)\n",
"\n",
" def forward(self, input, v):\n",
" output = (input * self.b0 + v[:, 0]) / self.a0\n",
" v = self._cat([(input * self.b1 + v[:, 1] - output * self.a1), (input * self.b2 - output * self.a2)]) / self.a0\n",
" return output, v\n",
"\n",
" def init_states(self, size):\n",
" v = torch.zeros(size, 2).to(next(self.parameters()).device)\n",
" return v\n",
"\n",
"\n",
"class DTDFII(nn.Module):\n",
" def __init__(self):\n",
" super(DTDFII, self).__init__()\n",
" self.cell = DTDFIICell()\n",
"\n",
" def forward(self, input, initial_states=None):\n",
" batch_size = input.shape[0]\n",
" sequence_length = input.shape[1]\n",
"\n",
" if initial_states is None:\n",
" states = input.new_zeros(batch_size, 2)\n",
" else:\n",
" states = initial_states\n",
"\n",
" out_sequence = torch.zeros_like(input)\n",
" for s_idx in range(sequence_length):\n",
" out_sequence[:, s_idx], states = self.cell(input[:, s_idx].view(-1), states)\n",
"\n",
" if initial_states is None:\n",
" return out_sequence\n",
" else:\n",
" return out_sequence, states"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Forward, CPU"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Benchmarking on 2 threads\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48fa30>\n",
"simple IIR: Implemented using for-loop\n",
" 223.83 ms\n",
" 1 measurement, 10 runs , 2 threads\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48fcd0>\n",
"lfilter: Implemented using torchaudio.lfilter\n",
" 24.17 ms\n",
" 1 measurement, 10 runs , 2 threads\n"
]
}
],
"source": [
"batch = 8\n",
"samples = 1024\n",
"\n",
"x = torch.randn(batch, samples)\n",
"base = DTDFII()\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"x.requires_grad = a.requires_grad = b.requires_grad = True\n",
"\n",
"num_threads = torch.get_num_threads()\n",
"print(f'Benchmarking on {num_threads} threads')\n",
"\n",
"\n",
"t0 = benchmark.Timer(\n",
" stmt='m(x)',\n",
" setup='',\n",
" globals={'x': x, 'm': base},\n",
" num_threads=num_threads,\n",
" label='simple IIR',\n",
" sub_label='Implemented using for-loop')\n",
"\n",
"t1 = benchmark.Timer(\n",
" stmt='lfilter.apply(x, a, b)',\n",
" setup='from differentiable_lfilter import lfilter',\n",
" globals={'x': x, 'a': a, 'b': b},\n",
" num_threads=num_threads,\n",
" label='lfilter',\n",
" sub_label='Implemented using torchaudio.lfilter')\n",
"\n",
"m0 = t0.timeit(10)\n",
"m1 = t1.timeit(10)\n",
"print(m0)\n",
"print(m1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Forward, GPU"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48f7f0>\n",
"simple IIR: Implemented using for-loop\n",
" 392.72 ms\n",
" 1 measurement, 1 runs , 1 thread\n",
"Mean: 392.72 ms\n",
"Median: 392.72 ms\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48f910>\n",
"lfilter: Implemented using torchaudio.lfilter\n",
" Median: 44.26 ms\n",
" IQR: 1.34 ms (43.56 to 44.90)\n",
" 5 measurements, 1 runs per measurement, 1 thread\n",
"Mean: 44.26 ms\n",
"Median: 44.26 ms\n"
]
}
],
"source": [
"batch = 8\n",
"samples = 1024\n",
"\n",
"x = torch.randn(batch, samples, device='cuda')\n",
"base = DTDFII().to('cuda')\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"x.requires_grad = a.requires_grad = b.requires_grad = True\n",
"\n",
"\n",
"t0 = benchmark.Timer(\n",
" stmt='m(x)',\n",
" setup='',\n",
" globals={'x': x, 'm': base},\n",
" label='simple IIR',\n",
" sub_label='Implemented using for-loop')\n",
"\n",
"t1 = benchmark.Timer(\n",
" stmt='lfilter.apply(x, a, b)',\n",
" setup='from differentiable_lfilter import lfilter',\n",
" globals={'x': x, 'a': a, 'b': b},\n",
" label='lfilter',\n",
" sub_label='Implemented using torchaudio.lfilter')\n",
"\n",
"m0 = t0.blocked_autorange()\n",
"m1 = t1.blocked_autorange()\n",
"print(m0)\n",
"print(f\"Mean: {m0.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m0.median * 1e3:6.2f} ms\")\n",
"print(m1)\n",
"print(f\"Mean: {m1.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m1.median * 1e3:6.2f} ms\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Forward+backward, GPU"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d38f5e0>\n",
"simple IIR: Implemented using for-loop\n",
" 486.54 ms\n",
" 1 measurement, 1 runs , 1 thread\n",
"Mean: 486.54 ms\n",
"Median: 486.54 ms\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d38f610>\n",
"lfilter: Implemented using torchaudio.lfilter\n",
" 35.69 ms\n",
" 1 measurement, 10 runs , 1 thread\n",
"Mean: 35.69 ms\n",
"Median: 35.69 ms\n"
]
}
],
"source": [
"batch = 8\n",
"samples = 256\n",
"\n",
"x = torch.randn(batch, samples, device='cuda')\n",
"base = DTDFII().to('cuda')\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"x.requires_grad = a.requires_grad = b.requires_grad = True\n",
"\n",
"dummy = base(x)\n",
"dummy.sum().backward()\n",
"\n",
"dummy = lfilter.apply(x, a, b)\n",
"dummy.sum().backward()\n",
"\n",
"assert x.grad is not None\n",
"assert a.grad is not None\n",
"assert b.grad is not None\n",
"\n",
"t0 = benchmark.Timer(\n",
" stmt='y = m(x)\\ny.mean().backward()',\n",
" setup='',\n",
" globals={'x': x, 'm': base},\n",
" label='simple IIR',\n",
" sub_label='Implemented using for-loop')\n",
"\n",
"t1 = benchmark.Timer(\n",
" stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n",
" setup='from differentiable_lfilter import lfilter',\n",
" globals={'x': x, 'a': a, 'b': b},\n",
" label='lfilter',\n",
" sub_label='Implemented using torchaudio.lfilter')\n",
"\n",
"m0 = t0.blocked_autorange()\n",
"m1 = t1.blocked_autorange()\n",
"print(m0)\n",
"print(f\"Mean: {m0.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m0.median * 1e3:6.2f} ms\")\n",
"print(m1)\n",
"print(f\"Mean: {m1.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m1.median * 1e3:6.2f} ms\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Different sizes, Forward+backward, GPU"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-------------- IIR filter -------------]\n",
" | for-loop | lfilter\n",
"1 threads: ------------------------------\n",
" [8, 16] | 30 | \u001b[34m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [8, 64] | \u001b[2m\u001b[91m 100 \u001b[0m\u001b[0m | 11 \n",
" [8, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 34 \u001b[0m\u001b[0m\n",
" [8, 1024] | \u001b[31m\u001b[1m 1287 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n",
" [16, 16] | 30 | \u001b[34m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [16, 64] | \u001b[2m\u001b[91m 81 \u001b[0m\u001b[0m | 11 \n",
" [16, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 34 \u001b[0m\u001b[0m\n",
" [16, 1024] | \u001b[31m\u001b[1m 1281 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n",
" [32, 16] | \u001b[92m\u001b[1m 20 \u001b[0m\u001b[0m | \u001b[92m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [32, 64] | \u001b[2m\u001b[91m 80 \u001b[0m\u001b[0m | \u001b[2m\u001b[91m 10 \u001b[0m\u001b[0m\n",
" [32, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n",
" [32, 1024] | \u001b[31m\u001b[1m 1533 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 100 \u001b[0m\u001b[0m\n",
" [64, 16] | 20 | \u001b[92m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [64, 64] | \u001b[2m\u001b[91m 80 \u001b[0m\u001b[0m | 11 \n",
" [64, 256] | \u001b[31m\u001b[1m 320 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n",
" [64, 1024] | \u001b[31m\u001b[1m 1316 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n",
"\n",
"Times are in milliseconds (ms).\n",
"\n"
]
}
],
"source": [
"base = DTDFII().to('cuda')\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"a.requires_grad = b.requires_grad = True\n",
"\n",
"results = []\n",
"batches = [8, 16, 32, 64]\n",
"samples = [16, 64, 256, 1024]\n",
"for batch, n in product(batches, samples):\n",
" label = 'IIR filter'\n",
" sub_label = f'[{batch}, {n}]'\n",
" x = torch.randn(batch, n, device='cuda')\n",
" x.requires_grad = True\n",
"\n",
" dummy = base(x)\n",
" dummy.sum().backward()\n",
"\n",
" dummy = lfilter.apply(x, a, b)\n",
" dummy.sum().backward()\n",
"\n",
" results.append(benchmark.Timer(\n",
" stmt='y = m(x)\\ny.mean().backward()',\n",
" setup='',\n",
" globals={'x': x, 'm': base},\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description='for-loop',\n",
" ).blocked_autorange(min_run_time=1))\n",
" results.append(benchmark.Timer(\n",
" stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n",
" setup='from differentiable_lfilter import lfilter',\n",
" globals={'x': x, 'a': a, 'b': b},\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description='lfilter',\n",
" ).blocked_autorange(min_run_time=1))\n",
"\n",
"compare = benchmark.Compare(results)\n",
"compare.trim_significant_figures()\n",
"compare.colorize()\n",
"compare.print()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[----------------------- IIR filter -----------------------]\n",
" | for-loop | lfilter\n",
"1 threads: -------------------------------------------------\n",
" 442 x 169 | \u001b[31m\u001b[1m 200 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 25 \u001b[0m\u001b[0m\n",
" 36 x 244 | \u001b[31m\u001b[1m 300 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n",
" 26 x 848 | \u001b[31m\u001b[1m 1129 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 110 \u001b[0m\u001b[0m\n",
" 126 x 2653 | \u001b[31m\u001b[1m 3389 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 330 \u001b[0m\u001b[0m\n",
" 1201 x 755 (discontiguous) | \u001b[31m\u001b[1m 990 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 98 \u001b[0m\u001b[0m\n",
" 56 x 917 | \u001b[31m\u001b[1m 1166 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 120 \u001b[0m\u001b[0m\n",
" 324 x 463 | \u001b[31m\u001b[1m 600 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 60 \u001b[0m\u001b[0m\n",
" 97 x 639 | \u001b[31m\u001b[1m 801 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 81 \u001b[0m\u001b[0m\n",
" 56 x 4 (discontiguous) | \u001b[92m\u001b[1m 5 \u001b[0m\u001b[0m | \u001b[92m\u001b[1m 3 \u001b[0m\u001b[0m\n",
" 192 x 183 (discontiguous) | \u001b[31m\u001b[1m 254 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 30 \u001b[0m\u001b[0m\n",
"\n",
"Times are in milliseconds (ms).\n",
"\n"
]
}
],
"source": [
"base = DTDFII().to('cuda')\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"a.requires_grad = b.requires_grad = True\n",
"\n",
"results = []\n",
"\n",
"example_fuzzer = benchmark.Fuzzer(\n",
" parameters = [\n",
" benchmark.FuzzedParameter('k0', minval=1, maxval=5000, distribution='loguniform'),\n",
" benchmark.FuzzedParameter('k1', minval=1, maxval=5000, distribution='loguniform'),\n",
" ],\n",
" tensors = [\n",
" benchmark.FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=1000000, probability_contiguous=0.6)\n",
" ],\n",
" seed=0,\n",
")\n",
"\n",
"for tensors, tensor_params, params in example_fuzzer.take(10):\n",
" sub_label=f\"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}\"\n",
" \n",
" x = tensors['x']\n",
" x = x.cuda()\n",
" x.requires_grad = True\n",
"\n",
" dummy = base(x)\n",
" dummy.sum().backward()\n",
"\n",
" dummy = lfilter.apply(x, a, b)\n",
" dummy.sum().backward()\n",
" \n",
" label = 'IIR filter'\n",
"\n",
" results.append(benchmark.Timer(\n",
" stmt='y = m(x)\\ny.mean().backward()',\n",
" setup='',\n",
" globals={'x': x, 'm': base},\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description='for-loop',\n",
" ).blocked_autorange(min_run_time=1))\n",
" results.append(benchmark.Timer(\n",
" stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n",
" setup='from differentiable_lfilter import lfilter',\n",
" globals={'x': x, 'a': a, 'b': b},\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description='lfilter',\n",
" ).blocked_autorange(min_run_time=1))\n",
"\n",
"compare = benchmark.Compare(results)\n",
"compare.trim_significant_figures()\n",
"compare.colorize()\n",
"compare.print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
@yoyololicon
Copy link
Author

This custom backward function has been added to the latest torchaudio master branch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment