Last active August 25, 2022 12:59
This `lfilter` can propagate gradients to the filter coefficients.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import lfilter as torch_lfilter
from torch.autograd import Function, gradcheck
def _lagged_xcorr(signal, weight, n_taps, batch):
    """Correlate ``signal`` with ``weight`` at lags ``0 .. n_taps - 1``.

    Computes ``r[k] = sum_{c, n} weight[c, n] * signal[c, n - k]`` summed over
    all ``batch`` channels ``c``; samples before the start of ``signal`` are
    treated as zero (left zero-padding).

    Args:
        signal: tensor whose last dimension is time; reshaped to
            ``(1, batch, T)``.
        weight: tensor of shape ``(batch, 1, T)``, one kernel per channel.
        n_taps: number of lags to return.
        batch: number of independent channels.

    Returns:
        1-D tensor of length ``n_taps``, indexed by lag ``k``.
    """
    padded = F.pad(signal.reshape(1, -1, signal.shape[-1]), [n_taps - 1, 0])
    # Grouped conv1d correlates each channel with its own kernel; its output
    # runs from lag n_taps - 1 down to lag 0, hence the final flip.
    return F.conv1d(padded, weight, groups=batch).sum((0, 1)).flip(0)


class lfilter(Function):
    """IIR filter ``y = (B(z) / A(z)) x`` with gradients for ``x``, ``a`` and ``b``.

    The all-pole part ``1 / A(z)`` is evaluated by torchaudio's ``lfilter``
    (called with numerator ``[1, 0, ..., 0]``) and the all-zero part ``B(z)``
    by a causal ``conv1d``. The backward pass is written by hand so that
    gradients reach the filter coefficients.

    Call it as ``lfilter.apply(x, a, b)``.
    """

    @staticmethod
    def forward(ctx, x, a, b) -> torch.Tensor:
        """Filter ``x`` along its last dimension.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: input signal; time is the last dimension and every leading
                dimension is treated as an independent channel.
            a: 1-D denominator coefficients ``[a0, a1, ...]``.
            b: 1-D numerator coefficients ``[b0, b1, ...]``; its length may
                differ from that of ``a``.

        Returns:
            Filtered signal with the same shape as ``x``.
        """
        with torch.no_grad():
            # Numerator [1, 0, ..., 0]: torchaudio applies only 1 / A(z).
            # The trailing False is torchaudio's ``clamp`` flag (no clipping).
            impulse = torch.zeros_like(a)
            impulse[0] = 1
            xh = torch_lfilter(x, a, impulse, False)
            # Causal FIR: left-pad by len(b) - 1 and correlate with the
            # reversed taps, i.e. y[n] = sum_k b[k] * xh[n - k].
            y = F.pad(xh.reshape(-1, 1, xh.shape[-1]), [b.numel() - 1, 0])
            y = F.conv1d(y, b.flip(0).view(1, 1, -1)).view(*xh.shape)
        ctx.save_for_backward(x, a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy):
        """Return the gradients ``(dx, da, db)``.

        Each entry is ``None`` when the corresponding input of ``forward``
        does not require a gradient.
        """
        x, a, b, xh = ctx.saved_tensors
        dx = da = db = None
        n_samples = x.shape[-1]
        batch = x.numel() // n_samples
        # The incoming gradient is not guaranteed to be contiguous, so
        # reshape (which copies if needed) instead of view.
        dy_rows = dy.reshape(-1, 1, n_samples)
        with torch.no_grad():
            if ctx.needs_input_grad[2]:
                # db[k] = sum_n dy[n] * xh[n - k]
                db = _lagged_xcorr(xh, dy_rows, b.numel(), batch)
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                # Adjoint of the FIR stage: dxh[n] = sum_k b[k] * dy[n + k].
                dxh = F.conv1d(F.pad(dy_rows, [0, b.numel() - 1]),
                               b.view(1, 1, -1)).view(*dy.shape)
                impulse = torch.zeros_like(a)
                if ctx.needs_input_grad[0]:
                    # Adjoint of the recursive stage: run 1 / A(z) over the
                    # time-reversed gradient, then reverse the result.
                    impulse[0] = 1
                    dx = torch_lfilter(dxh.flip(-1), a, impulse, False).flip(-1)
                if ctx.needs_input_grad[1]:
                    # d xh / d a[k] = -z^{-k} (xh / A(z)): filter xh with
                    # numerator [-1, 0, ...] and correlate with dxh.
                    impulse[0] = -1
                    dxh_da = torch_lfilter(xh, a, impulse, False)
                    # One lag per coefficient of ``a`` (not ``b``): the two
                    # coefficient vectors may have different lengths.
                    da = _lagged_xcorr(dxh_da, dxh.reshape(-1, 1, n_samples),
                                       a.numel(), batch)
        return dx, da, db
if __name__ == '__main__':
    # Quick self-check: profile one forward/backward pass, then compare the
    # hand-written backward against finite differences with gradcheck.
    # Fall back to CPU so the check also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_cuda = device == 'cuda'
    x = torch.randn(4, 256, device=device, dtype=torch.double)
    a = torch.rand(3, device=device, dtype=torch.double)
    b = torch.rand(3, device=device, dtype=torch.double)
    # With a[0] = 1 and a[1], a[2] drawn from [0, 1), both poles lie inside
    # the unit circle: the filter is stable over the 256 samples and
    # gradcheck stays well conditioned (a tiny a[0] would blow up).
    a[0] = 1
    a.requires_grad = True
    b.requires_grad = True
    x.requires_grad = True
    print(a, b)
    with torch.autograd.profiler.profile(use_cuda=use_cuda, profile_memory=True) as prof:
        y = lfilter.apply(x, a, b)
        loss = y.abs().sum()
        # Include the custom backward in the profile as well.
        loss.backward()
    sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=5))
    print(gradcheck(lfilter.apply, (x, a, b), eps=1e-6))
"cells": [
"/home/ycy/miniconda3/envs/hrtf_notebooks/lib/python3.8/site-packages/torchaudio/backend/ UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to for the detail.\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.utils.benchmark as benchmark\n",
"from numpy.random import uniform\n",
"from differentiable_lfilter import lfilter\n",
"from itertools import product"
"## Define simple second-order IIR\n"
"class DTDFIICell(nn.Module):\n",
"    # One time step of a second-order IIR filter in transposed direct form II.\n",
"    def __init__(self):\n",
"        super(DTDFIICell, self).__init__()\n",
"        self.b0 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n",
"        self.b1 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n",
"        self.b2 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n",
"        self.a0 = nn.Parameter(torch.FloatTensor([1]))\n",
"        self.a1 = nn.Parameter(torch.FloatTensor([uniform(-0.5, 0.5)]))\n",
"        self.a2 = nn.Parameter(torch.FloatTensor([uniform(-0.5, 0.5)]))\n",
"\n",
"    def _cat(self, vectors):\n",
"        # Stack per-sample state components into a (batch, 2) state tensor.\n",
"        return torch.cat([v_.unsqueeze(-1) for v_ in vectors], dim=-1)\n",
"\n",
"    def forward(self, input, v):\n",
"        # a0 * y[n] = b0 * x[n] + v0: only the output is normalised by a0;\n",
"        # dividing the state update by a0 again is wrong whenever a0 != 1.\n",
"        output = (input * self.b0 + v[:, 0]) / self.a0\n",
"        v = self._cat([(input * self.b1 + v[:, 1] - output * self.a1), (input * self.b2 - output * self.a2)])\n",
"        return output, v\n",
"\n",
"    def init_states(self, size):\n",
"        v = torch.zeros(size, 2).to(next(self.parameters()).device)\n",
"        return v\n",
"\n",
"\n",
"class DTDFII(nn.Module):\n",
"    # Runs DTDFIICell sample by sample over the last dimension of a (batch, time) input.\n",
"    def __init__(self):\n",
"        super(DTDFII, self).__init__()\n",
"        self.cell = DTDFIICell()\n",
"\n",
"    def forward(self, input, initial_states=None):\n",
"        batch_size = input.shape[0]\n",
"        sequence_length = input.shape[1]\n",
"        if initial_states is None:\n",
"            states = input.new_zeros(batch_size, 2)\n",
"        else:\n",
"            states = initial_states\n",
"        out_sequence = torch.zeros_like(input)\n",
"        for s_idx in range(sequence_length):\n",
"            out_sequence[:, s_idx], states = self.cell(input[:, s_idx].view(-1), states)\n",
"        if initial_states is None:\n",
"            return out_sequence\n",
"        else:\n",
"            return out_sequence, states"
"## Forward, CPU"
"Benchmarking on 2 threads\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48fa30>\n",
"simple IIR: Implemented using for-loop\n",
" 223.83 ms\n",
" 1 measurement, 10 runs , 2 threads\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48fcd0>\n",
"lfilter: Implemented using torchaudio.lfilter\n",
" 24.17 ms\n",
" 1 measurement, 10 runs , 2 threads\n"
"batch = 8\n",
"samples = 1024\n",
"x = torch.randn(batch, samples)\n",
"base = DTDFII()\n",
"# Same coefficients as the for-loop model, packed into 1-D tensors for lfilter.\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"x.requires_grad = a.requires_grad = b.requires_grad = True\n",
"num_threads = torch.get_num_threads()\n",
"print(f'Benchmarking on {num_threads} threads')\n",
"t0 = benchmark.Timer(\n",
"    stmt='m(x)',\n",
"    setup='',\n",
"    globals={'x': x, 'm': base},\n",
"    num_threads=num_threads,\n",
"    label='simple IIR',\n",
"    sub_label='Implemented using for-loop')\n",
"t1 = benchmark.Timer(\n",
"    stmt='lfilter.apply(x, a, b)',\n",
"    setup='from differentiable_lfilter import lfilter',\n",
"    globals={'x': x, 'a': a, 'b': b},\n",
"    num_threads=num_threads,\n",
"    label='lfilter',\n",
"    sub_label='Implemented using torchaudio.lfilter')\n",
"m0 = t0.timeit(10)\n",
"m1 = t1.timeit(10)\n",
"print(m0)\n",
"print(m1)"
"## Forward, GPU"
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48f7f0>\n",
"simple IIR: Implemented using for-loop\n",
" 392.72 ms\n",
" 1 measurement, 1 runs , 1 thread\n",
"Mean: 392.72 ms\n",
"Median: 392.72 ms\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48f910>\n",
"lfilter: Implemented using torchaudio.lfilter\n",
" Median: 44.26 ms\n",
" IQR: 1.34 ms (43.56 to 44.90)\n",
" 5 measurements, 1 runs per measurement, 1 thread\n",
"Mean: 44.26 ms\n",
"Median: 44.26 ms\n"
"batch = 8\n",
"samples = 1024\n",
"x = torch.randn(batch, samples, device='cuda')\n",
"base = DTDFII().to('cuda')\n",
"# Same coefficients as the for-loop model, packed into 1-D tensors for lfilter.\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"x.requires_grad = a.requires_grad = b.requires_grad = True\n",
"t0 = benchmark.Timer(\n",
"    stmt='m(x)',\n",
"    setup='',\n",
"    globals={'x': x, 'm': base},\n",
"    label='simple IIR',\n",
"    sub_label='Implemented using for-loop')\n",
"t1 = benchmark.Timer(\n",
"    stmt='lfilter.apply(x, a, b)',\n",
"    setup='from differentiable_lfilter import lfilter',\n",
"    globals={'x': x, 'a': a, 'b': b},\n",
"    label='lfilter',\n",
"    sub_label='Implemented using torchaudio.lfilter')\n",
"m0 = t0.blocked_autorange()\n",
"m1 = t1.blocked_autorange()\n",
"print(m0)\n",
"print(f\"Mean: {m0.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m0.median * 1e3:6.2f} ms\")\n",
"print(m1)\n",
"print(f\"Mean: {m1.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m1.median * 1e3:6.2f} ms\")"
"## Forward+backward, GPU"
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d38f5e0>\n",
"simple IIR: Implemented using for-loop\n",
" 486.54 ms\n",
" 1 measurement, 1 runs , 1 thread\n",
"Mean: 486.54 ms\n",
"Median: 486.54 ms\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d38f610>\n",
"lfilter: Implemented using torchaudio.lfilter\n",
" 35.69 ms\n",
" 1 measurement, 10 runs , 1 thread\n",
"Mean: 35.69 ms\n",
"Median: 35.69 ms\n"
"batch = 8\n",
"samples = 256\n",
"x = torch.randn(batch, samples, device='cuda')\n",
"base = DTDFII().to('cuda')\n",
"# Same coefficients as the for-loop model, packed into 1-D tensors for lfilter.\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"x.requires_grad = a.requires_grad = b.requires_grad = True\n",
"# Run forward + backward through both implementations once; without the\n",
"# backward calls the .grad asserts below can never pass.\n",
"dummy = base(x)\n",
"dummy.sum().backward()\n",
"dummy = lfilter.apply(x, a, b)\n",
"dummy.sum().backward()\n",
"assert x.grad is not None\n",
"assert a.grad is not None\n",
"assert b.grad is not None\n",
"t0 = benchmark.Timer(\n",
"    stmt='y = m(x)\\ny.mean().backward()',\n",
"    setup='',\n",
"    globals={'x': x, 'm': base},\n",
"    label='simple IIR',\n",
"    sub_label='Implemented using for-loop')\n",
"t1 = benchmark.Timer(\n",
"    stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n",
"    setup='from differentiable_lfilter import lfilter',\n",
"    globals={'x': x, 'a': a, 'b': b},\n",
"    label='lfilter',\n",
"    sub_label='Implemented using torchaudio.lfilter')\n",
"m0 = t0.blocked_autorange()\n",
"m1 = t1.blocked_autorange()\n",
"print(m0)\n",
"print(f\"Mean: {m0.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m0.median * 1e3:6.2f} ms\")\n",
"print(m1)\n",
"print(f\"Mean: {m1.mean * 1e3:6.2f} ms\")\n",
"print(f\"Median: {m1.median * 1e3:6.2f} ms\")"
"# Different sizes, Forward+backward, GPU"
"[-------------- IIR filter -------------]\n",
" | for-loop | lfilter\n",
"1 threads: ------------------------------\n",
" [8, 16] | 30 | \u001b[34m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [8, 64] | \u001b[2m\u001b[91m 100 \u001b[0m\u001b[0m | 11 \n",
" [8, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 34 \u001b[0m\u001b[0m\n",
" [8, 1024] | \u001b[31m\u001b[1m 1287 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n",
" [16, 16] | 30 | \u001b[34m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [16, 64] | \u001b[2m\u001b[91m 81 \u001b[0m\u001b[0m | 11 \n",
" [16, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 34 \u001b[0m\u001b[0m\n",
" [16, 1024] | \u001b[31m\u001b[1m 1281 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n",
" [32, 16] | \u001b[92m\u001b[1m 20 \u001b[0m\u001b[0m | \u001b[92m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [32, 64] | \u001b[2m\u001b[91m 80 \u001b[0m\u001b[0m | \u001b[2m\u001b[91m 10 \u001b[0m\u001b[0m\n",
" [32, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n",
" [32, 1024] | \u001b[31m\u001b[1m 1533 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 100 \u001b[0m\u001b[0m\n",
" [64, 16] | 20 | \u001b[92m\u001b[1m 6 \u001b[0m\u001b[0m\n",
" [64, 64] | \u001b[2m\u001b[91m 80 \u001b[0m\u001b[0m | 11 \n",
" [64, 256] | \u001b[31m\u001b[1m 320 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n",
" [64, 1024] | \u001b[31m\u001b[1m 1316 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n",
"Times are in milliseconds (ms).\n",
"base = DTDFII().to('cuda')\n",
"# Same coefficients as the for-loop model, packed into 1-D tensors for lfilter.\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"a.requires_grad = b.requires_grad = True\n",
"results = []\n",
"batches = [8, 16, 32, 64]\n",
"samples = [16, 64, 256, 1024]\n",
"for batch, n in product(batches, samples):\n",
"    label = 'IIR filter'\n",
"    sub_label = f'[{batch}, {n}]'\n",
"    x = torch.randn(batch, n, device='cuda')\n",
"    x.requires_grad = True\n",
"    # Run each implementation once (forward + backward) before timing it.\n",
"    dummy = base(x)\n",
"    dummy.sum().backward()\n",
"    dummy = lfilter.apply(x, a, b)\n",
"    dummy.sum().backward()\n",
"    results.append(benchmark.Timer(\n",
"        stmt='y = m(x)\\ny.mean().backward()',\n",
"        setup='',\n",
"        globals={'x': x, 'm': base},\n",
"        label=label,\n",
"        sub_label=sub_label,\n",
"        description='for-loop',\n",
"    ).blocked_autorange(min_run_time=1))\n",
"    results.append(benchmark.Timer(\n",
"        stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n",
"        setup='from differentiable_lfilter import lfilter',\n",
"        globals={'x': x, 'a': a, 'b': b},\n",
"        label=label,\n",
"        sub_label=sub_label,\n",
"        description='lfilter',\n",
"    ).blocked_autorange(min_run_time=1))\n",
"compare = benchmark.Compare(results)\n",
"compare.print()"
"[----------------------- IIR filter -----------------------]\n",
" | for-loop | lfilter\n",
"1 threads: -------------------------------------------------\n",
" 442 x 169 | \u001b[31m\u001b[1m 200 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 25 \u001b[0m\u001b[0m\n",
" 36 x 244 | \u001b[31m\u001b[1m 300 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n",
" 26 x 848 | \u001b[31m\u001b[1m 1129 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 110 \u001b[0m\u001b[0m\n",
" 126 x 2653 | \u001b[31m\u001b[1m 3389 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 330 \u001b[0m\u001b[0m\n",
" 1201 x 755 (discontiguous) | \u001b[31m\u001b[1m 990 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 98 \u001b[0m\u001b[0m\n",
" 56 x 917 | \u001b[31m\u001b[1m 1166 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 120 \u001b[0m\u001b[0m\n",
" 324 x 463 | \u001b[31m\u001b[1m 600 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 60 \u001b[0m\u001b[0m\n",
" 97 x 639 | \u001b[31m\u001b[1m 801 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 81 \u001b[0m\u001b[0m\n",
" 56 x 4 (discontiguous) | \u001b[92m\u001b[1m 5 \u001b[0m\u001b[0m | \u001b[92m\u001b[1m 3 \u001b[0m\u001b[0m\n",
" 192 x 183 (discontiguous) | \u001b[31m\u001b[1m 254 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 30 \u001b[0m\u001b[0m\n",
"Times are in milliseconds (ms).\n",
"base = DTDFII().to('cuda')\n",
"# Same coefficients as the for-loop model, packed into 1-D tensors for lfilter.\n",
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n",
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n",
"a.requires_grad = b.requires_grad = True\n",
"results = []\n",
"example_fuzzer = benchmark.Fuzzer(\n",
"    parameters = [\n",
"        benchmark.FuzzedParameter('k0', minval=1, maxval=5000, distribution='loguniform'),\n",
"        benchmark.FuzzedParameter('k1', minval=1, maxval=5000, distribution='loguniform'),\n",
"    ],\n",
"    tensors = [\n",
"        benchmark.FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=1000000, probability_contiguous=0.6)\n",
"    ],\n",
"    seed=0,\n",
")\n",
"for tensors, tensor_params, params in example_fuzzer.take(10):\n",
"    sub_label=f\"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}\"\n",
"\n",
"    x = tensors['x']\n",
"    x = x.cuda()\n",
"    x.requires_grad = True\n",
"    # Run each implementation once (forward + backward) before timing it.\n",
"    dummy = base(x)\n",
"    dummy.sum().backward()\n",
"    dummy = lfilter.apply(x, a, b)\n",
"    dummy.sum().backward()\n",
"\n",
"    label = 'IIR filter'\n",
"    results.append(benchmark.Timer(\n",
"        stmt='y = m(x)\\ny.mean().backward()',\n",
"        setup='',\n",
"        globals={'x': x, 'm': base},\n",
"        label=label,\n",
"        sub_label=sub_label,\n",
"        description='for-loop',\n",
"    ).blocked_autorange(min_run_time=1))\n",
"    results.append(benchmark.Timer(\n",
"        stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n",
"        setup='from differentiable_lfilter import lfilter',\n",
"        globals={'x': x, 'a': a, 'b': b},\n",
"        label=label,\n",
"        sub_label=sub_label,\n",
"        description='lfilter',\n",
"    ).blocked_autorange(min_run_time=1))\n",
"compare = benchmark.Compare(results)\n",
"compare.print()"
This custom backward function has since been added to the latest torchaudio master branch.

