Skip to content

Instantly share code, notes, and snippets.

@yoyololicon
Last active August 25, 2022 12:59
Show Gist options
  • Save yoyololicon/f63f601d62187562070a61377cec9bf8 to your computer and use it in GitHub Desktop.
Save yoyololicon/f63f601d62187562070a61377cec9bf8 to your computer and use it in GitHub Desktop.
This lfilter can propagate gradients to the filter coefficients.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import lfilter as torch_lfilter
from torch.autograd import Function, gradcheck
class lfilter(Function):
    """IIR filter with a hand-written backward pass for ``x``, ``a`` and ``b``.

    The filter is evaluated in two stages:

    1. an all-pole stage delegated to ``torch_lfilter`` with denominator ``a``
       and a unit-impulse numerator, producing the intermediate signal ``xh``;
    2. an FIR stage with numerator ``b``, implemented with ``F.conv1d``.

    ``backward`` differentiates both stages explicitly, so gradients reach the
    filter coefficients as well as the input signal.

    Usage: ``y = lfilter.apply(x, a, b)`` where the last dimension of ``x`` is
    filtered and ``a`` / ``b`` are 1-D coefficient tensors.
    """

    @staticmethod
    def forward(ctx, x, a, b) -> torch.Tensor:
        """Filter ``x`` along its last dimension.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: input signal; every leading dimension is treated as batch.
            a: 1-D denominator coefficients.
            b: 1-D numerator coefficients.

        Returns:
            The filtered signal, same shape as the all-pole output ``xh``.
        """
        with torch.no_grad():
            # Unit-impulse numerator: torch_lfilter then only applies the
            # recursive (denominator) part of the filter.
            dummy = torch.zeros_like(a)
            dummy[0] = 1
            # NOTE(review): the 4th positional argument is presumably
            # torchaudio's ``clamp`` flag -- confirm against the installed version.
            xh = torch_lfilter(x, a, dummy, False)

            # FIR stage: left-pad so the output keeps the input length, and
            # flip ``b`` because conv1d computes a cross-correlation.
            rows = xh.reshape(-1, 1, xh.shape[-1])
            rows = F.pad(rows, [b.numel() - 1, 0])
            y = F.conv1d(rows, b.flip(0).view(1, 1, -1)).reshape(xh.shape)

        ctx.save_for_backward(x, a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy) -> tuple:
        """Return ``(dx, da, db)`` for the upstream gradient ``dy``.

        An entry is ``None`` when the matching input does not need a gradient
        according to ``ctx.needs_input_grad``.
        """
        x, a, b, xh = ctx.saved_tensors
        dx, da, db = None, None, None

        need_x = ctx.needs_input_grad[0]
        need_a = ctx.needs_input_grad[1]
        need_b = ctx.needs_input_grad[2]

        n_time = dy.shape[-1]
        # Number of independent signals (product of all leading dimensions).
        batch = x.numel() // x.shape[-1]

        with torch.no_grad():
            # reshape (not view): the upstream gradient may be non-contiguous.
            dy_rows = dy.reshape(-1, 1, n_time)

            if need_b:
                # Correlate the left-padded all-pole output with dy, one group
                # per signal, then sum over signals: db[k] = sum dy[n] * xh[n-k].
                xh_padded = F.pad(xh.reshape(1, -1, n_time), [b.numel() - 1, 0])
                db = F.conv1d(xh_padded, dy_rows, groups=batch).sum((0, 1)).flip(0)

            if need_x or need_a:
                # Gradient w.r.t. xh: adjoint of the FIR stage (right-pad and
                # correlate with the un-flipped b).
                dxh = F.conv1d(
                    F.pad(dy_rows, [0, b.numel() - 1]),
                    b.view(1, 1, -1),
                ).reshape(dy.shape)
                dummy = torch.zeros_like(a)

                if need_x:
                    # Adjoint of the all-pole stage: run the same recursion on
                    # the time-reversed gradient and reverse the result.
                    dummy[0] = 1
                    dx = torch_lfilter(dxh.flip(-1), a, dummy, False).flip(-1)

                if need_a:
                    # Sensitivity of xh to a: -xh passed through the all-pole stage.
                    dummy[0] = -1
                    dxhda = torch_lfilter(xh, a, dummy, False)
                    # Pad by len(a) - 1 so the result has exactly a.numel()
                    # entries even when a and b differ in length.
                    dxhda_padded = F.pad(
                        dxhda.reshape(1, -1, n_time), [a.numel() - 1, 0]
                    )
                    da = F.conv1d(
                        dxhda_padded,
                        dxh.reshape(-1, 1, n_time),
                        groups=batch,
                    ).sum((0, 1)).flip(0)

        return dx, da, db
if __name__ == '__main__':
    # Demo / self-check: profile one forward+backward pass, then verify the
    # hand-written gradients numerically with gradcheck.
    # Use the GPU when present, otherwise fall back to CPU so the script still runs.
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    # Double precision is required for gradcheck's finite differences.
    x = torch.randn(4, 256, device=device, dtype=torch.double)
    a = torch.rand(3, device=device, dtype=torch.double)
    b = torch.rand(3, device=device, dtype=torch.double)
    # Normalise the denominator so that a[0] == 1.
    a.div_(a[0].item())

    a.requires_grad = True
    b.requires_grad = True
    x.requires_grad = True
    print(a, b)

    with torch.autograd.profiler.profile(use_cuda=use_cuda, profile_memory=True) as prof:
        y = lfilter.apply(x, a, b)
        loss = y.abs().sum()
        loss.backward()

    # Sort by device time only when a CUDA device was actually profiled.
    sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=5))
    print(gradcheck(lfilter.apply, (x, a, b), eps=1e-6))
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@yoyololicon
Copy link
Author

This custom backward function has been added to the newest torchaudio master branch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment